#!/usr/bin/env python
# /// script
# dependencies = [
#   "jax[cuda12]",
#   "equinox",
#   "matplotlib",
#   "jaxtyping",
# ]
# ///
"""Sample spin configurations from a trained Ising Generator checkpoint.

Usage:
    python sample.py --checkpoint model.eqx [--num-samples 16]
                     [--output samples.npy] [--plot] [--seed 0]

How autoregressive sampling works
----------------------------------
The model is trained with a causal (lower-triangular) attention mask, so at
position t the output logits[t] are a function of spins s_0 … s_t only.
We exploit this to sample the full sequence one spin at a time:

  1. Sample s_0 uniformly (the model has no BOS token).
  2. For t = 0, 1, …, L²-2:
       a. Run the full forward pass on the current token buffer.
          Spins at positions > t are still placeholder zeros, but causal
          masking prevents the network from attending to them.
       b. Draw s_{t+1} ~ Categorical(softmax(logits[t])).
       c. Write s_{t+1} into the buffer.

This is O(L⁶) in compute: L² sampling steps, each a full forward pass whose
attention over the L²-token sequence costs O(L⁴).  For a 32×32 lattice that
is roughly 10⁹ attention-score operations.  `jax.lax.scan` compiles the loop
body once so subsequent calls are fast.
"""

import argparse
from pathlib import Path

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Int

from model import Generator, gen_config, snake_order


# ---------------------------------------------------------------------------
# Checkpoint I/O
# ---------------------------------------------------------------------------

def load_checkpoint(
    path: Path,
    config: dict = gen_config,
    key: jax.Array | None = None,
) -> Generator:
    """Deserialise a Generator from *path*.

    A fresh skeleton model is initialised with *config* purely to define the
    tree structure; its weights are immediately overwritten by the checkpoint.
    *key* therefore only seeds the throwaway initialisation, and any key
    works.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    skeleton = Generator(config=config, key=key)
    return eqx.tree_deserialise_leaves(path, skeleton)
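
# Usage sketch (an assumption, not confirmed by this file: the checkpoint is
# taken to have been written by the training script with
# ``eqx.tree_serialise_leaves(path, model)``, the counterpart of the
# deserialisation above):
#
#   model = load_checkpoint(Path("model.eqx"))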


# ---------------------------------------------------------------------------
# Sampling
# ---------------------------------------------------------------------------

def sample_sequence(
    model: Generator,
    key: jax.Array,
) -> Int[Array, " seq_len"]:
    """Autoregressively sample one spin sequence in snake order.

    This function is JAX-traceable and safe to use inside ``jax.vmap`` or
    ``jax.lax.scan``.  JIT-compile it (or wrap in ``sample_batch``) for best
    performance — the first call will take longer due to compilation.
    """
    lattice_size = model.encoder.embedder_block.lattice_size
    seq_len = lattice_size * lattice_size

    def step(carry, t):
        tokens, step_key = carry
        step_key, sample_key = jax.random.split(step_key)
        logits = model({"token_ids": tokens}, enable_dropout=False, key=None)
        # logits[t] → distribution over s_{t+1}
        next_token = jax.random.categorical(sample_key, logits[t])
        tokens = tokens.at[t + 1].set(next_token)
        return (tokens, step_key), None

    key, first_key = jax.random.split(key)
    first_token = jax.random.randint(first_key, shape=(), minval=0, maxval=2)
    tokens = jnp.zeros(seq_len, dtype=jnp.int32).at[0].set(first_token)
    (tokens, _), _ = jax.lax.scan(step, (tokens, key), jnp.arange(seq_len - 1))
    return tokens
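
# Usage sketch: ``sample_sequence`` is pure and traceable, so it can be
# JIT-compiled directly; ``sample_batch`` below does this for you via
# ``jax.vmap`` + ``eqx.filter_jit``.
#
#   tokens = eqx.filter_jit(sample_sequence)(model, jax.random.PRNGKey(0))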


@eqx.filter_jit
def sample_batch(
    model: Generator,
    num_samples: int,
    key: jax.Array,
) -> Int[Array, "num_samples seq_len"]:
    """Sample *num_samples* configurations in parallel (vmapped + JIT'd).

    The first call triggers compilation; subsequent calls with the same
    ``num_samples`` reuse the compiled code.
    """
    keys = jax.random.split(key, num_samples)
    return jax.vmap(sample_sequence, in_axes=(None, 0))(model, keys)
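
# Example: sixteen independent configurations in one call; the result has
# shape (16, L²) with token values in {0, 1}.
#
#   tokens = sample_batch(model, 16, jax.random.PRNGKey(0))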


# ---------------------------------------------------------------------------
# Grid conversion
# ---------------------------------------------------------------------------

def tokens_to_grid(
    tokens: np.ndarray | Array,
    lattice_size: int,
) -> np.ndarray:
    """Convert a snake-ordered token sequence {0, 1} → a {-1, +1} L×L grid."""
    tokens = np.asarray(tokens)
    rows, cols = snake_order(lattice_size)
    grid = np.empty((lattice_size, lattice_size), dtype=np.int8)
    grid[rows, cols] = (tokens * 2 - 1).astype(np.int8)
    return grid
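
# Worked example on a hypothetical 2×2 lattice: tokens [0, 1, 1, 0] map to
# spins [-1, +1, +1, -1], scattered to the first four (row, col) positions
# returned by ``snake_order(2)``, i.e. undoing the snake-order flattening
# used at training time.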


def tokens_to_grids(
    tokens: np.ndarray | Array,
    lattice_size: int,
) -> np.ndarray:
    """Batch version of ``tokens_to_grid``.  Input shape: (N, L²)."""
    tokens = np.asarray(tokens)
    rows, cols = snake_order(lattice_size)
    grids = np.empty((len(tokens), lattice_size, lattice_size), dtype=np.int8)
    grids[:, rows, cols] = (tokens * 2 - 1).astype(np.int8)
    return grids


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args():
    p = argparse.ArgumentParser(description="Sample from a trained Ising Generator.")
    p.add_argument("--checkpoint",   type=Path, required=True,
                   help="Path to the .eqx checkpoint file.")
    p.add_argument("--num-samples",  type=int,  default=16)
    p.add_argument("--output",       type=Path, default=None,
                   help="Save sampled {-1,+1} grids as a .npy file (N, L, L).")
    p.add_argument("--plot",         action="store_true",
                   help="Display a grid of sampled configurations with matplotlib.")
    p.add_argument("--seed",         type=int,  default=0)
    return p.parse_args()


def main():
    args = parse_args()

    print(f"Loading checkpoint from {args.checkpoint} …")
    model = load_checkpoint(args.checkpoint)
    lattice_size = model.encoder.embedder_block.lattice_size
    print(f"  lattice_size={lattice_size}, seq_len={lattice_size**2}")

    key = jax.random.PRNGKey(args.seed)
    print(f"Sampling {args.num_samples} configurations "
          f"(compiling on first call) …")

    # Sample in fixed-size micro-batches to avoid OOM on GPU.
    # Changing _MICRO_BATCH triggers recompilation, so keep it constant.
    _MICRO_BATCH = 4
    all_tokens = []
    n_calls = -(-args.num_samples // _MICRO_BATCH)  # ceiling division
    for i in range(n_calls):
        key, subkey = jax.random.split(key)
        batch = sample_batch(model, _MICRO_BATCH, subkey)
        all_tokens.append(np.asarray(batch))
        print(f"  batch {i+1}/{n_calls} done")
    tokens = np.concatenate(all_tokens)[:args.num_samples]

    grids = tokens_to_grids(tokens, lattice_size)   # (N, L, L), values {-1, +1}
    print(f"  shape: {grids.shape}  dtype: {grids.dtype}")
    print(f"  mean magnetization : {grids.mean():.4f}")
    print(f"  mean |magnetization|: {np.abs(grids.mean(axis=(1, 2))).mean():.4f}")

    if args.output is not None:
        np.save(args.output, grids)
        print(f"Saved → {args.output}")

    if args.plot:
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("matplotlib not available; skipping plot (pip install matplotlib).")
            return

        cols = min(8, args.num_samples)
        rows = (args.num_samples + cols - 1) // cols
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5))
        axes = np.array(axes).reshape(-1)
        for i, ax in enumerate(axes):
            if i < len(grids):
                ax.imshow(grids[i], cmap="gray", vmin=-1, vmax=1, interpolation="nearest")
            ax.axis("off")
        fig.suptitle(
            f"Sampled Ising configurations  "
            f"(L={lattice_size}, n={args.num_samples})",
            fontsize=10,
        )
        plt.tight_layout()
        plt.show()


if __name__ == "__main__":
    main()