Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # /// script | |
| # dependencies = [ | |
| # "jax[cuda12]", | |
| # "equinox", | |
| # "matplotlib", | |
| # "jaxtyping", | |
| # ] | |
| # /// | |
| """Sample spin configurations from a trained Ising Generator checkpoint. | |
| Usage: | |
| python sample.py --checkpoint model.eqx [--num-samples 16] | |
| [--output samples.npy] [--plot] [--seed 0] | |
| How autoregressive sampling works | |
| ---------------------------------- | |
| The model is trained with a causal (lower-triangular) attention mask, so at | |
| position t the output logits[t] are a function of spins s_0 … s_t only. | |
| We exploit this to sample the full sequence one spin at a time: | |
| 1. Sample s_0 uniformly (the model has no BOS token). | |
| 2. For t = 0, 1, …, L²-2: | |
| a. Run the full forward pass on the current token buffer. | |
| Spins at positions > t are still placeholder zeros, but causal | |
| masking prevents the network from attending to them. | |
| b. Draw s_{t+1} ~ Categorical(softmax(logits[t])). | |
| c. Write s_{t+1} into the buffer. | |
| This is O(L⁶) in compute (L² sampling steps, each paying O(L⁴) for | |
| attention over the full L²-token sequence) — about 1 B ops for a 32×32 | |
| lattice. `jax.lax.scan` compiles the loop body once so subsequent | |
| calls are fast. | |
| """ | |
| import argparse | |
| import functools | |
| from pathlib import Path | |
| import equinox as eqx | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from jaxtyping import Array, Int | |
| from model import Generator, gen_config, snake_order | |
| # --------------------------------------------------------------------------- | |
| # Checkpoint I/O | |
| # --------------------------------------------------------------------------- | |
def load_checkpoint(
    path: Path,
    config: dict = gen_config,
    key: jax.Array | None = None,
) -> Generator:
    """Load a serialised Generator from *path*.

    An uninitialised skeleton model is built from *config* purely to give
    the deserialiser a matching pytree structure; every weight leaf is then
    overwritten by the checkpoint contents, so the *key* used to build the
    skeleton has no effect on the returned model — any key works.
    """
    rng = jax.random.PRNGKey(0) if key is None else key
    skeleton = Generator(config=config, key=rng)
    return eqx.tree_deserialise_leaves(path, skeleton)
| # --------------------------------------------------------------------------- | |
| # Sampling | |
| # --------------------------------------------------------------------------- | |
def sample_sequence(
    model: Generator,
    key: jax.random.PRNGKey,
) -> Int[Array, " seq_len"]:
    """Autoregressively sample one spin sequence in snake order.

    This function is JAX-traceable and safe to use inside ``jax.vmap`` or
    ``jax.lax.scan``. JIT-compile it (or wrap in ``sample_batch``) for best
    performance — the first call will take longer due to compilation.
    """
    # Sequence length is derived from the model itself, so no shape
    # argument is needed.
    lattice_size = model.encoder.embedder_block.lattice_size
    seq_len = lattice_size * lattice_size

    def step(carry, t):
        # carry = (token buffer, current RNG key); t is the position whose
        # logits condition the draw of the *next* spin.
        tokens, step_key = carry
        step_key, sample_key = jax.random.split(step_key)
        # Full forward pass each step. Positions > t still hold placeholder
        # zeros, but the causal mask keeps them out of logits[t].
        logits = model({"token_ids": tokens}, enable_dropout=False, key=None)
        # logits[t] → distribution over s_{t+1}
        next_token = jax.random.categorical(sample_key, logits[t])
        tokens = tokens.at[t + 1].set(next_token)
        return (tokens, step_key), None

    key, first_key = jax.random.split(key)
    # The model has no BOS token, so s_0 is drawn uniformly from {0, 1}.
    first_token = jax.random.randint(first_key, shape=(), minval=0, maxval=2)
    tokens = jnp.zeros(seq_len, dtype=jnp.int32).at[0].set(first_token)
    # scan over t = 0 … seq_len-2: each step fills position t+1.
    (tokens, _), _ = jax.lax.scan(step, (tokens, key), jnp.arange(seq_len - 1))
    return tokens
@eqx.filter_jit
def _sample_batch_jit(
    model: Generator,
    keys: jax.Array,
) -> Int[Array, "num_samples seq_len"]:
    """JIT-compiled core: vmap ``sample_sequence`` over a stack of RNG keys.

    ``eqx.filter_jit`` (rather than plain ``jax.jit``) is required because
    the Equinox module contains non-array leaves that must be treated as
    static.
    """
    return jax.vmap(sample_sequence, in_axes=(None, 0))(model, keys)


def sample_batch(
    model: Generator,
    num_samples: int,
    key: jax.random.PRNGKey,
) -> Int[Array, "num_samples seq_len"]:
    """Sample *num_samples* configurations in parallel (vmapped + JIT'd).

    The first call triggers compilation; subsequent calls with the same
    ``num_samples`` reuse the compiled code.

    Note: the original body only vmapped — without jit every call was
    re-traced and run eagerly, contradicting this docstring and the
    micro-batch recompilation comment in ``main``. The jitted helper above
    restores the documented behaviour.
    """
    keys = jax.random.split(key, num_samples)
    return _sample_batch_jit(model, keys)
| # --------------------------------------------------------------------------- | |
| # Grid conversion | |
| # --------------------------------------------------------------------------- | |
def tokens_to_grid(
    tokens: np.ndarray | Array,
    lattice_size: int,
) -> np.ndarray:
    """Map a snake-ordered {0, 1} token sequence onto a {-1, +1} L×L grid."""
    seq = np.asarray(tokens)
    # Affine map {0, 1} → {-1, +1}, stored compactly as int8 spins.
    spins = (seq * 2 - 1).astype(np.int8)
    rows, cols = snake_order(lattice_size)
    grid = np.empty((lattice_size, lattice_size), dtype=np.int8)
    grid[rows, cols] = spins
    return grid
def tokens_to_grids(
    tokens: np.ndarray | Array,
    lattice_size: int,
) -> np.ndarray:
    """Batch version of ``tokens_to_grid``. Input shape: (N, L²)."""
    batch = np.asarray(tokens)
    # Convert each row independently, then stack into an (N, L, L) array.
    grids = [tokens_to_grid(sequence, lattice_size) for sequence in batch]
    return np.stack(grids)
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
def parse_args():
    """Build and evaluate the command-line interface for this script."""
    parser = argparse.ArgumentParser(
        description="Sample from a trained Ising Generator."
    )
    parser.add_argument(
        "--checkpoint",
        type=Path,
        required=True,
        help="Path to the .eqx checkpoint file.",
    )
    parser.add_argument("--num-samples", type=int, default=16)
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Save sampled {-1,+1} grids as a .npy file (N, L, L).",
    )
    parser.add_argument(
        "--plot",
        action="store_true",
        help="Display a grid of sampled configurations with matplotlib.",
    )
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()
def main():
    """CLI entry point: load a checkpoint, sample spin configurations,
    print summary statistics, and optionally save and/or plot the grids."""
    args = parse_args()
    # Guard early: num_samples < 1 would otherwise reach
    # np.concatenate([]) below and raise an opaque ValueError.
    if args.num_samples < 1:
        raise SystemExit("--num-samples must be a positive integer")
    print(f"Loading checkpoint from {args.checkpoint} …")
    model = load_checkpoint(args.checkpoint)
    # Lattice size is recovered from the model, so the CLI needs no shape flag.
    lattice_size = model.encoder.embedder_block.lattice_size
    print(f" lattice_size={lattice_size}, seq_len={lattice_size**2}")
    key = jax.random.PRNGKey(args.seed)
    print(f"Sampling {args.num_samples} configurations "
          f"(compiling on first call) …")
    # Sample in fixed-size micro-batches to avoid OOM on GPU.
    # Changing _MICRO_BATCH triggers recompilation, so keep it constant.
    _MICRO_BATCH = 4
    all_tokens = []
    n_calls = -(-args.num_samples // _MICRO_BATCH)  # ceiling division
    for i in range(n_calls):
        key, subkey = jax.random.split(key)
        batch = sample_batch(model, _MICRO_BATCH, subkey)
        all_tokens.append(np.asarray(batch))
        print(f" batch {i+1}/{n_calls} done")
    # The last micro-batch may overshoot; trim to the requested count.
    tokens = np.concatenate(all_tokens)[:args.num_samples]
    grids = tokens_to_grids(tokens, lattice_size)  # (N, L, L), values {-1, +1}
    print(f" shape: {grids.shape} dtype: {grids.dtype}")
    print(f" mean magnetization : {grids.mean():.4f}")
    print(f" mean |magnetization|: {np.abs(grids.mean(axis=(1, 2))).mean():.4f}")
    if args.output is not None:
        np.save(args.output, grids)
        print(f"Saved → {args.output}")
    if args.plot:
        # matplotlib is optional for this script; degrade gracefully.
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("matplotlib not available; skipping plot (pip install matplotlib).")
            return
        cols = min(8, args.num_samples)
        rows = (args.num_samples + cols - 1) // cols  # ceiling division
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5))
        # plt.subplots returns a scalar, 1-D, or 2-D axes object depending
        # on (rows, cols); normalise to a flat array before iterating.
        axes = np.array(axes).reshape(-1)
        for i, ax in enumerate(axes):
            if i < len(grids):
                ax.imshow(grids[i], cmap="gray", vmin=-1, vmax=1, interpolation="nearest")
            ax.axis("off")
        fig.suptitle(
            f"Sampled Ising configurations "
            f"(L={lattice_size}, n={args.num_samples})",
            fontsize=10,
        )
        plt.tight_layout()
        plt.show()


if __name__ == "__main__":
    main()