#!/usr/bin/env python
# /// script
# dependencies = [
#     "jax[cuda12]",
#     "equinox",
#     "matplotlib",
#     "jaxtyping",
# ]
# ///
"""Sample spin configurations from a trained Ising Generator checkpoint.

Usage:
    python sample.py --checkpoint model.eqx [--num-samples 16]
                     [--output samples.npy] [--plot] [--seed 0]

How autoregressive sampling works
---------------------------------
The model is trained with a causal (lower-triangular) attention mask, so at
position t the output logits[t] are a function of spins s_0 … s_t only. We
exploit this to sample the full sequence one spin at a time:

1. Sample s_0 uniformly (the model has no BOS token).
2. For t = 0, 1, …, L²-2:
   a. Run the full forward pass on the current token buffer. Spins at
      positions > t are still placeholder zeros, but causal masking
      prevents the network from attending to them.
   b. Draw s_{t+1} ~ Categorical(softmax(logits[t])).
   c. Write s_{t+1} into the buffer.

This is O(L⁶) in compute: L² sampling steps, each paying O(L⁴) for
attention over L² tokens — about 10⁹ ops for a 32×32 lattice.
`jax.lax.scan` compiles the loop body once so subsequent calls are fast.
"""

import argparse
from pathlib import Path

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Int

from model import Generator, gen_config, snake_order

# ---------------------------------------------------------------------------
# Checkpoint I/O
# ---------------------------------------------------------------------------


def load_checkpoint(
    path: Path,
    config: dict = gen_config,
    key: jax.Array | None = None,
) -> Generator:
    """Deserialise a Generator from *path*.

    A fresh model skeleton is initialised from *config* and its weights are
    immediately overwritten by the checkpoint, so in practice any *key*
    works; it only matters if you care about the seed used for the skeleton.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    skeleton = Generator(config=config, key=key)
    return eqx.tree_deserialise_leaves(path, skeleton)


# ---------------------------------------------------------------------------
# Sampling
# ---------------------------------------------------------------------------


def sample_sequence(
    model: Generator,
    key: jax.Array,
) -> Int[Array, " seq_len"]:
    """Autoregressively sample one spin sequence in snake order.

    This function is JAX-traceable and safe to use inside ``jax.vmap`` or
    ``jax.lax.scan``. JIT-compile it (or wrap in ``sample_batch``) for best
    performance; the first call will take longer due to compilation.
    """
    lattice_size = model.encoder.embedder_block.lattice_size
    seq_len = lattice_size * lattice_size

    def step(carry, t):
        tokens, step_key = carry
        step_key, sample_key = jax.random.split(step_key)
        # Full forward pass; causal masking guarantees logits[t] depends
        # only on tokens[0..t], never on the placeholder zeros beyond t.
        logits = model({"token_ids": tokens}, enable_dropout=False, key=None)
        # logits[t] → distribution over s_{t+1}
        next_token = jax.random.categorical(sample_key, logits[t])
        tokens = tokens.at[t + 1].set(next_token)
        return (tokens, step_key), None

    key, first_key = jax.random.split(key)
    # No BOS token: draw the first spin uniformly from {0, 1}.
    first_token = jax.random.randint(first_key, shape=(), minval=0, maxval=2)
    tokens = jnp.zeros(seq_len, dtype=jnp.int32).at[0].set(first_token)
    (tokens, _), _ = jax.lax.scan(step, (tokens, key), jnp.arange(seq_len - 1))
    return tokens


@eqx.filter_jit
def sample_batch(
    model: Generator,
    num_samples: int,
    key: jax.Array,
) -> Int[Array, "num_samples seq_len"]:
    """Sample *num_samples* configurations in parallel (vmapped + JIT'd).

    The first call triggers compilation; subsequent calls with the same
    ``num_samples`` reuse the compiled code.
    """
    keys = jax.random.split(key, num_samples)
    return jax.vmap(sample_sequence, in_axes=(None, 0))(model, keys)
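
# Programmatic use, as a minimal sketch (the checkpoint path, seed, and batch
# size below are illustrative assumptions, not fixed by this script):
#
#     model = load_checkpoint(Path("model.eqx"))
#     tokens = sample_batch(model, 8, jax.random.PRNGKey(42))  # (8, L²), {0, 1}
#     grids = tokens_to_grids(tokens, model.encoder.embedder_block.lattice_size)
#
# ``tokens_to_grids`` (defined below) maps the snake-ordered {0, 1} tokens
# back onto {-1, +1} L×L lattices.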
""" keys = jax.random.split(key, num_samples) return jax.vmap(sample_sequence, in_axes=(None, 0))(model, keys) # --------------------------------------------------------------------------- # Grid conversion # --------------------------------------------------------------------------- def tokens_to_grid( tokens: np.ndarray | Array, lattice_size: int, ) -> np.ndarray: """Convert a snake-ordered token sequence {0, 1} → a {-1, +1} L×L grid.""" tokens = np.asarray(tokens) rows, cols = snake_order(lattice_size) grid = np.empty((lattice_size, lattice_size), dtype=np.int8) grid[rows, cols] = (tokens * 2 - 1).astype(np.int8) return grid def tokens_to_grids( tokens: np.ndarray | Array, lattice_size: int, ) -> np.ndarray: """Batch version of ``tokens_to_grid``. Input shape: (N, L²).""" tokens = np.asarray(tokens) return np.stack([tokens_to_grid(tokens[i], lattice_size) for i in range(len(tokens))]) # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def parse_args(): p = argparse.ArgumentParser(description="Sample from a trained Ising Generator.") p.add_argument("--checkpoint", type=Path, required=True, help="Path to the .eqx checkpoint file.") p.add_argument("--num-samples", type=int, default=16) p.add_argument("--output", type=Path, default=None, help="Save sampled {-1,+1} grids as a .npy file (N, L, L).") p.add_argument("--plot", action="store_true", help="Display a grid of sampled configurations with matplotlib.") p.add_argument("--seed", type=int, default=0) return p.parse_args() def main(): args = parse_args() print(f"Loading checkpoint from {args.checkpoint} …") model = load_checkpoint(args.checkpoint) lattice_size = model.encoder.embedder_block.lattice_size print(f" lattice_size={lattice_size}, seq_len={lattice_size**2}") key = jax.random.PRNGKey(args.seed) print(f"Sampling {args.num_samples} configurations " f"(compiling on first call) …") # Sample in fixed-size micro-batches to avoid OOM on GPU. # Changing _MICRO_BATCH triggers recompilation, so keep it constant. _MICRO_BATCH = 4 all_tokens = [] n_calls = -(-args.num_samples // _MICRO_BATCH) # ceiling division for i in range(n_calls): key, subkey = jax.random.split(key) batch = sample_batch(model, _MICRO_BATCH, subkey) all_tokens.append(np.asarray(batch)) print(f" batch {i+1}/{n_calls} done") tokens = np.concatenate(all_tokens)[:args.num_samples] grids = tokens_to_grids(tokens, lattice_size) # (N, L, L), values {-1, +1} print(f" shape: {grids.shape} dtype: {grids.dtype}") print(f" mean magnetization : {grids.mean():.4f}") print(f" mean |magnetization|: {np.abs(grids.mean(axis=(1, 2))).mean():.4f}") if args.output is not None: np.save(args.output, grids) print(f"Saved → {args.output}") if args.plot: try: import matplotlib.pyplot as plt except ImportError: print("matplotlib not available; skipping plot (pip install matplotlib).") return cols = min(8, args.num_samples) rows = (args.num_samples + cols - 1) // cols fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5)) axes = np.array(axes).reshape(-1) for i, ax in enumerate(axes): if i < len(grids): ax.imshow(grids[i], cmap="gray", vmin=-1, vmax=1, interpolation="nearest") ax.axis("off") fig.suptitle( f"Sampled Ising configurations " f"(L={lattice_size}, n={args.num_samples})", fontsize=10, ) plt.tight_layout() plt.show() if __name__ == "__main__": main()