#!/usr/bin/env python
# /// script
# dependencies = [
# "jax[cuda12]",
# "equinox",
# "matplotlib",
# "jaxtyping",
# ]
# ///
"""Sample spin configurations from a trained Ising Generator checkpoint.
Usage:
python sample.py --checkpoint model.eqx [--num-samples 16]
[--output samples.npy] [--plot] [--seed 0]
How autoregressive sampling works
----------------------------------
The model is trained with a causal (lower-triangular) attention mask, so at
position t the output logits[t] are a function of spins s_0 … s_t only.
We exploit this to sample the full sequence one spin at a time:
1. Sample s_0 uniformly (the model has no BOS token).
2. For t = 0, 1, …, L²-2:
a. Run the full forward pass on the current token buffer.
Spins at positions > t are still placeholder zeros, but causal
masking prevents the network from attending to them.
b. Draw s_{t+1} ~ Categorical(softmax(logits[t])).
c. Write s_{t+1} into the buffer.
This is O(L⁴) in compute (L² steps × L² attention), which is 1 B ops for a
32×32 lattice. `jax.lax.scan` compiles the loop body once so subsequent
calls are fast.
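
Example
-------
A minimal sketch of programmatic use (``model.eqx`` is a placeholder path)::

    from pathlib import Path
    model = load_checkpoint(Path("model.eqx"))
    tokens = sample_batch(model, num_samples=4, key=jax.random.PRNGKey(0))
    grids = tokens_to_grids(tokens, model.encoder.embedder_block.lattice_size)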
"""
import argparse
import functools
from pathlib import Path
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Int
from model import Generator, gen_config, snake_order
# ---------------------------------------------------------------------------
# Checkpoint I/O
# ---------------------------------------------------------------------------
def load_checkpoint(
path: Path,
config: dict = gen_config,
key: jax.Array | None = None,
) -> Generator:
"""Deserialise a Generator from *path*.
A fresh model is initialised with *config* (weights are immediately
overwritten), so *key* only needs to be reproducible across calls if you
care about the random seed used for the skeleton — in practice any key
works.
"""
if key is None:
key = jax.random.PRNGKey(0)
skeleton = Generator(config=config, key=key)
return eqx.tree_deserialise_leaves(path, skeleton)
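
# A checkpoint like this is typically written with the inverse call (a sketch,
# using the standard Equinox serialisation API):
#   eqx.tree_serialise_leaves("model.eqx", model)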
# ---------------------------------------------------------------------------
# Sampling
# ---------------------------------------------------------------------------
def sample_sequence(
model: Generator,
    key: jax.Array,
) -> Int[Array, " seq_len"]:
"""Autoregressively sample one spin sequence in snake order.
This function is JAX-traceable and safe to use inside ``jax.vmap`` or
``jax.lax.scan``. JIT-compile it (or wrap in ``sample_batch``) for best
performance — the first call will take longer due to compilation.
"""
lattice_size = model.encoder.embedder_block.lattice_size
seq_len = lattice_size * lattice_size
def step(carry, t):
tokens, step_key = carry
step_key, sample_key = jax.random.split(step_key)
logits = model({"token_ids": tokens}, enable_dropout=False, key=None)
# logits[t] → distribution over s_{t+1}
next_token = jax.random.categorical(sample_key, logits[t])
tokens = tokens.at[t + 1].set(next_token)
return (tokens, step_key), None
key, first_key = jax.random.split(key)
first_token = jax.random.randint(first_key, shape=(), minval=0, maxval=2)
tokens = jnp.zeros(seq_len, dtype=jnp.int32).at[0].set(first_token)
(tokens, _), _ = jax.lax.scan(step, (tokens, key), jnp.arange(seq_len - 1))
return tokens
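
# A single sequence can also be sampled directly; wrapping it in
# eqx.filter_jit avoids retracing on every call. A sketch:
#   fast_sample = eqx.filter_jit(sample_sequence)
#   tokens = fast_sample(model, jax.random.PRNGKey(0))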
@eqx.filter_jit
def sample_batch(
model: Generator,
num_samples: int,
    key: jax.Array,
) -> Int[Array, "num_samples seq_len"]:
"""Sample *num_samples* configurations in parallel (vmapped + JIT'd).
The first call triggers compilation; subsequent calls with the same
``num_samples`` reuse the compiled code.
"""
keys = jax.random.split(key, num_samples)
return jax.vmap(sample_sequence, in_axes=(None, 0))(model, keys)
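
# eqx.filter_jit treats the Python int ``num_samples`` as a static argument,
# so, for example:
#   sample_batch(model, 4, k1)   # compiles
#   sample_batch(model, 4, k2)   # reuses the compiled code
#   sample_batch(model, 8, k3)   # recompiles for the new static value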
# ---------------------------------------------------------------------------
# Grid conversion
# ---------------------------------------------------------------------------
def tokens_to_grid(
tokens: np.ndarray | Array,
lattice_size: int,
) -> np.ndarray:
"""Convert a snake-ordered token sequence {0, 1} → a {-1, +1} L×L grid."""
tokens = np.asarray(tokens)
rows, cols = snake_order(lattice_size)
grid = np.empty((lattice_size, lattice_size), dtype=np.int8)
grid[rows, cols] = (tokens * 2 - 1).astype(np.int8)
return grid
def tokens_to_grids(
tokens: np.ndarray | Array,
lattice_size: int,
) -> np.ndarray:
"""Batch version of ``tokens_to_grid``. Input shape: (N, L²)."""
tokens = np.asarray(tokens)
    return np.stack([tokens_to_grid(seq, lattice_size) for seq in tokens])
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args():
p = argparse.ArgumentParser(description="Sample from a trained Ising Generator.")
p.add_argument("--checkpoint", type=Path, required=True,
help="Path to the .eqx checkpoint file.")
p.add_argument("--num-samples", type=int, default=16)
p.add_argument("--output", type=Path, default=None,
help="Save sampled {-1,+1} grids as a .npy file (N, L, L).")
p.add_argument("--plot", action="store_true",
help="Display a grid of sampled configurations with matplotlib.")
p.add_argument("--seed", type=int, default=0)
return p.parse_args()
def main():
args = parse_args()
print(f"Loading checkpoint from {args.checkpoint} …")
model = load_checkpoint(args.checkpoint)
lattice_size = model.encoder.embedder_block.lattice_size
print(f" lattice_size={lattice_size}, seq_len={lattice_size**2}")
key = jax.random.PRNGKey(args.seed)
print(f"Sampling {args.num_samples} configurations "
f"(compiling on first call) …")
# Sample in fixed-size micro-batches to avoid OOM on GPU.
# Changing _MICRO_BATCH triggers recompilation, so keep it constant.
_MICRO_BATCH = 4
all_tokens = []
n_calls = -(-args.num_samples // _MICRO_BATCH) # ceiling division
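    # e.g. num_samples=10 with _MICRO_BATCH=4 gives n_calls=3 (12 sequences
    # sampled), truncated back to 10 when concatenating below.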
for i in range(n_calls):
key, subkey = jax.random.split(key)
batch = sample_batch(model, _MICRO_BATCH, subkey)
all_tokens.append(np.asarray(batch))
print(f" batch {i+1}/{n_calls} done")
tokens = np.concatenate(all_tokens)[:args.num_samples]
grids = tokens_to_grids(tokens, lattice_size) # (N, L, L), values {-1, +1}
print(f" shape: {grids.shape} dtype: {grids.dtype}")
print(f" mean magnetization : {grids.mean():.4f}")
print(f" mean |magnetization|: {np.abs(grids.mean(axis=(1, 2))).mean():.4f}")
if args.output is not None:
np.save(args.output, grids)
print(f"Saved → {args.output}")
if args.plot:
try:
import matplotlib.pyplot as plt
except ImportError:
print("matplotlib not available; skipping plot (pip install matplotlib).")
return
cols = min(8, args.num_samples)
rows = (args.num_samples + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5))
        axes = np.array(axes).reshape(-1)  # flatten; also handles the 1x1 single-Axes case
for i, ax in enumerate(axes):
if i < len(grids):
ax.imshow(grids[i], cmap="gray", vmin=-1, vmax=1, interpolation="nearest")
ax.axis("off")
fig.suptitle(
f"Sampled Ising configurations "
f"(L={lattice_size}, n={args.num_samples})",
fontsize=10,
)
plt.tight_layout()
plt.show()
if __name__ == "__main__":
main()