bertran-yorro committed
Commit 5c85f22 · verified · 1 Parent(s): 2ef4436

Initial upload: model, training scripts, Gradio app, data

Files changed (14)
  1. app.py +287 -0
  2. eval.py +386 -0
  3. examples.png +0 -0
  4. ising.py +20 -0
  5. main.py +26 -0
  6. metadata.json +17 -0
  7. model.py +362 -0
  8. requirements.txt +9 -0
  9. sample.py +208 -0
  10. samples-2-epoch.png +0 -0
  11. spins.npy +3 -0
  12. spins_test.npy +3 -0
  13. train.py +229 -0
  14. vi_train.py +288 -0
app.py ADDED
@@ -0,0 +1,287 @@
+ import shlex
+ import subprocess
+ import sys
+ from pathlib import Path
+
+ import gradio as gr
+ import matplotlib
+ matplotlib.use("Agg")
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ # ---------------------------------------------------------------------------
+ # Paths
+ # ---------------------------------------------------------------------------
+
+ OUT = Path("outputs")
+ CE_CKPT = OUT / "ce_checkpoint.eqx"
+ VI_CKPT = OUT / "vi_checkpoint.eqx"
+ CE_SMPL = OUT / "samples_ce.npy"
+ VI_SMPL = OUT / "samples_vi.npy"
+ TRAIN_DATA = Path("spins.npy")
+ TEST_DATA = Path("spins_test.npy")
+
+
+ # ---------------------------------------------------------------------------
+ # Subprocess helpers
+ # ---------------------------------------------------------------------------
+
+ def _stream(command: list[str]):
+     """Run a command and yield log lines in real time."""
+     log = ["$ " + " ".join(shlex.quote(p) for p in command), ""]
+     yield "\n".join(log)
+     proc = subprocess.Popen(
+         command,
+         stdout=subprocess.PIPE,
+         stderr=subprocess.STDOUT,
+         text=True,
+         bufsize=1,
+     )
+     assert proc.stdout is not None
+     for line in proc.stdout:
+         log.append(line.rstrip())
+         yield "\n".join(log[-300:])
+     rc = proc.wait()
+     log += ["", f"— exited {rc} —"]
+     yield "\n".join(log[-300:])
+
+
+ # ---------------------------------------------------------------------------
+ # Sample grid figure
+ # ---------------------------------------------------------------------------
+
+ def _samples_figure(path: Path, title: str, n: int = 16) -> plt.Figure | None:
+     if not path.exists():
+         return None
+     grids = np.load(path).astype(np.float32)[:n]  # (N, L, L), values ±1
+     cols = min(8, len(grids))
+     rows = (len(grids) + cols - 1) // cols
+     fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.4, rows * 1.4))
+     axes = np.array(axes).reshape(-1)
+     mags = grids.mean(axis=(1, 2))
+     for i, ax in enumerate(axes):
+         if i < len(grids):
+             ax.imshow(grids[i], cmap="gray", vmin=-1, vmax=1, interpolation="nearest")
+             ax.set_title(f"m={mags[i]:.2f}", fontsize=6)
+         ax.axis("off")
+     fig.suptitle(title, fontsize=9)
+     plt.tight_layout()
+     return fig
+
+
+ # ---------------------------------------------------------------------------
+ # Tab 1 – Cross-entropy training
+ # ---------------------------------------------------------------------------
+
+ def run_ce(mode, epochs, batch_size, lr, max_steps):
+     OUT.mkdir(parents=True, exist_ok=True)
+     cmd = [
+         sys.executable, "train.py",
+         "--data", str(TRAIN_DATA),
+         "--batch-size", str(int(batch_size)),
+         "--learning-rate", str(lr),
+         "--output-checkpoint", str(CE_CKPT),
+     ]
+     if mode == "Smoke":
+         cmd += ["--epochs", "1", "--max-train-steps", "5", "--max-eval-batches", "2"]
+     else:
+         cmd += ["--epochs", str(int(epochs))]
+         if int(max_steps) > 0:
+             cmd += ["--max-train-steps", str(int(max_steps))]
+     for log in _stream(cmd):
+         yield log, None
+     ckpt = str(CE_CKPT) if CE_CKPT.exists() else None
+     yield log, ckpt
+
+
+ # ---------------------------------------------------------------------------
+ # Tab 2 – Variational inference fine-tuning
+ # ---------------------------------------------------------------------------
+
+ def run_vi(mode, steps, batch_size, lr, warm_start):
+     OUT.mkdir(parents=True, exist_ok=True)
+     if warm_start and not CE_CKPT.exists():
+         yield "⚠ CE checkpoint not found. Run CE training first, or uncheck warm-start.", None
+         return
+     cmd = [
+         sys.executable, "vi_train.py",
+         "--batch-size", str(int(batch_size)),
+         "--learning-rate", str(lr),
+         "--output-checkpoint", str(VI_CKPT),
+     ]
+     if warm_start and CE_CKPT.exists():
+         cmd += ["--checkpoint", str(CE_CKPT)]
+     if mode == "Smoke":
+         cmd += ["--num-steps", "3", "--log-every", "1"]
+     else:
+         cmd += ["--num-steps", str(int(steps))]
+     for log in _stream(cmd):
+         yield log, None
+     ckpt = str(VI_CKPT) if VI_CKPT.exists() else None
+     yield log, ckpt
+
+
+ # ---------------------------------------------------------------------------
+ # Tab 3 – Sample & Eval
+ # ---------------------------------------------------------------------------
+
+ def run_eval(which, num_samples, seed):
+     OUT.mkdir(parents=True, exist_ok=True)
+     log_lines = []
+
+     def emit(msg=""):
+         log_lines.append(msg)
+         return (
+             "\n".join(log_lines[-200:]),
+             None, None,  # CE figure, VI figure
+         )
+
+     run_ce_ = which in ("CE", "Both")
+     run_vi_ = which in ("VI", "Both")
+
+     if run_ce_ and not CE_CKPT.exists():
+         yield emit("⚠ CE checkpoint not found. Run CE training first.")
+         return
+     if run_vi_ and not VI_CKPT.exists():
+         yield emit("⚠ VI checkpoint not found. Run VI training first.")
+         return
+
+     # ── Generate samples ───────────────────────────────────────────────────
+     for ckpt, out_path, label in [
+         (CE_CKPT, CE_SMPL, "CE"),
+         (VI_CKPT, VI_SMPL, "VI"),
+     ]:
+         if (label == "CE" and not run_ce_) or (label == "VI" and not run_vi_):
+             continue
+         log_lines.append(f"\n── Generating {num_samples} {label} samples ──")
+         yield "\n".join(log_lines[-200:]), None, None
+         cmd = [
+             sys.executable, "sample.py",
+             "--checkpoint", str(ckpt),
+             "--num-samples", str(int(num_samples)),
+             "--output", str(out_path),
+             "--seed", str(int(seed)),
+         ]
+         for chunk in _stream(cmd):
+             log_lines[-1:] = chunk.splitlines()[-10:]
+             yield "\n".join(log_lines[-200:]), None, None
+
+     # ── Run eval ──────────────────────────────────────────────────────────
+     for ckpt, smpl, label in [
+         (CE_CKPT, CE_SMPL, "CE"),
+         (VI_CKPT, VI_SMPL, "VI"),
+     ]:
+         if (label == "CE" and not run_ce_) or (label == "VI" and not run_vi_):
+             continue
+         log_lines.append(f"\n── Evaluating {label} model ──")
+         yield "\n".join(log_lines[-200:]), None, None
+         cmd = [
+             sys.executable, "eval.py",
+             "--checkpoint", str(ckpt),
+             "--test-data", str(TEST_DATA),
+             "--num-samples", str(int(num_samples)),
+             "--samples-file", str(smpl),
+             "--seed", str(int(seed)),
+         ]
+         for chunk in _stream(cmd):
+             log_lines[-1:] = chunk.splitlines()[-20:]
+             yield "\n".join(log_lines[-200:]), None, None
+
+     # ── Build figures ──────────────────────────────────────────────────────
+     ce_fig = _samples_figure(CE_SMPL, "CE samples") if run_ce_ else None
+     vi_fig = _samples_figure(VI_SMPL, "VI samples") if run_vi_ else None
+     yield "\n".join(log_lines[-200:]), ce_fig, vi_fig
+
+
+ # ---------------------------------------------------------------------------
+ # Gradio UI
+ # ---------------------------------------------------------------------------
+
+ with gr.Blocks(title="Ising Transformer", theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         "# 2D Ising Transformer\n"
+         "Autoregressive transformer trained on the 2D Ising model at the critical "
+         "temperature T_c ≈ 2.27. "
+         "Run **CE training** first, optionally fine-tune with **Variational Inference**, "
+         "then **Sample & Eval** to compare both against the held-out test set."
+     )
+
+     with gr.Tabs():
+
+         # ── Tab 1: CE training ──────────────────────────────────────────────
+         with gr.Tab("Cross-Entropy Training"):
+             gr.Markdown(
+                 "Trains the model to maximise `log q(s)` on the training spin "
+                 "configurations (teacher forcing, causal attention). "
+                 "A *Smoke* run does 5 steps to verify everything compiles."
+             )
+             with gr.Row():
+                 ce_mode = gr.Radio(["Smoke", "Full"], value="Smoke", label="Mode")
+                 ce_epoch = gr.Number(value=10, precision=0, minimum=1, label="Epochs")
+                 ce_bs = gr.Number(value=32, precision=0, minimum=1, label="Batch size")
+             with gr.Row():
+                 ce_lr = gr.Number(value=1e-4, label="Learning rate")
+                 ce_steps = gr.Number(value=0, precision=0, minimum=0, label="Max steps (0 = no cap)")
+             ce_run = gr.Button("Run CE Training", variant="primary")
+             ce_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
+             ce_ckpt = gr.File(label="Checkpoint")
+             ce_run.click(
+                 run_ce,
+                 inputs=[ce_mode, ce_epoch, ce_bs, ce_lr, ce_steps],
+                 outputs=[ce_logs, ce_ckpt],
+             )
+
+         # ── Tab 2: VI fine-tuning ─────────────────────────────────────────
+         with gr.Tab("Variational Inference Fine-tuning"):
+             gr.Markdown(
+                 "Minimises the variational free energy `F = ⟨E(s)⟩ − T·H[q]` using "
+                 "the REINFORCE gradient estimator. Warm-starting from the CE "
+                 "checkpoint is strongly recommended. "
+                 "A *Smoke* run does 3 steps."
+             )
+             with gr.Row():
+                 vi_mode = gr.Radio(["Smoke", "Full"], value="Smoke", label="Mode")
+                 vi_steps = gr.Number(value=200, precision=0, minimum=1, label="Steps")
+                 vi_bs = gr.Number(value=16, precision=0, minimum=1, label="Batch size")
+             with gr.Row():
+                 vi_lr = gr.Number(value=1e-4, label="Learning rate")
+                 vi_warm = gr.Checkbox(value=True, label="Warm-start from CE checkpoint")
+             vi_run = gr.Button("Run VI Fine-tuning", variant="primary")
+             vi_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
+             vi_ckpt = gr.File(label="Checkpoint")
+             vi_run.click(
+                 run_vi,
+                 inputs=[vi_mode, vi_steps, vi_bs, vi_lr, vi_warm],
+                 outputs=[vi_logs, vi_ckpt],
+             )
+
+         # ── Tab 3: Sample & Eval ──────────────────────────────────────────
+         with gr.Tab("Sample & Eval"):
+             gr.Markdown(
+                 "Generates spin configurations from the selected model(s), then runs "
+                 "the physical-observable evaluation against `spins_test.npy` "
+                 "(a held-out set, never seen during training).\n\n"
+                 "Features compared: magnetisation, energy, two-point correlations, "
+                 "cluster statistics. Distance reported as **Mahalanobis D** in the "
+                 "decorrelated feature space."
+             )
+             with gr.Row():
+                 ev_which = gr.Radio(
+                     ["CE", "VI", "Both"], value="Both", label="Model(s) to evaluate"
+                 )
+                 ev_n = gr.Number(value=64, precision=0, minimum=4, label="Num samples")
+                 ev_seed = gr.Number(value=0, precision=0, label="Seed")
+             ev_run = gr.Button("Run Sample & Eval", variant="primary")
+             ev_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
+             with gr.Row():
+                 ev_ce_img = gr.Plot(label="CE samples")
+                 ev_vi_img = gr.Plot(label="VI samples")
+             ev_run.click(
+                 run_eval,
+                 inputs=[ev_which, ev_n, ev_seed],
+                 outputs=[ev_logs, ev_ce_img, ev_vi_img],
+             )
+
+
+ if __name__ == "__main__":
+     demo.queue(default_concurrency_limit=1).launch()
eval.py ADDED
@@ -0,0 +1,386 @@
+ #!/usr/bin/env python
+ # /// script
+ # dependencies = [
+ #     "jax[cuda12]",
+ #     "equinox",
+ #     "scipy", "tqdm",
+ #     "jaxtyping",
+ # ]
+ # ///
+ """Evaluate a trained Generator against held-out test samples.
+
+ For each configuration we compute an 11-dimensional feature vector of physical
+ observables. The Mahalanobis distance between the real and generated feature
+ distributions gives a single scalar measure of model quality.
+
+ Per-sample feature vector
+ -------------------------
+     m, m^2, |m|   magnetisation and its moments
+     e, e^2        nearest-neighbour energy per spin (periodic BC)
+     C(r=1,2,4,8)  connected two-point correlation at r = 1, 2, 4, 8
+     s_mean/N      mean cluster size (4-connected, open BC)
+     s_max/N       largest cluster size
+
+ Ensemble statistics (printed for reference, not part of Mahalanobis)
+ --------------------------------------------------------------------
+     chi = N · Var(m) / T          magnetic susceptibility
+     C_v = N · Var(e) / T²         specific heat
+     U4  = 1 − <m^4>/(3<m^2>^2)    Binder cumulant
+           → 2/3 in the ordered phase
+           → 0 in the disordered phase
+           ≈ 0.47 at T_c for the 2D Ising model (L→∞)
+
+ Distance
+ --------
+     D = sqrt( Δμ^T Σ_real^{-1} Δμ )
+
+ where Δμ = μ_gen − μ_real and Σ_real is the sample covariance of the
+ real test features. Per-feature z-scores Δμ_i / σ_real_i are also
+ reported so you can see which observables deviate most.
+ """
+
+ import argparse
+ from pathlib import Path
+
+ import numpy as np
+ import scipy.ndimage
+ import jax
+ from tqdm.auto import tqdm
+
+ from model import gen_config, snake_order
+ from sample import load_checkpoint, sample_batch, tokens_to_grids
+ from train import load_ising_data
+
+
+ # ---------------------------------------------------------------------------
+ # Physical constants
+ # ---------------------------------------------------------------------------
+
+ J = 1.0
+ T_C = 2.0 / np.log(1.0 + np.sqrt(2.0))  # exact: 2J / ln(1+√2) ≈ 2.2692
+
+ FEATURE_NAMES = [
+     "m", "m^2", "|m|",
+     "e", "e^2",
+     "C(r=1)", "C(r=2)", "C(r=4)", "C(r=8)",
+     "s_mean/N", "s_max/N",
+ ]
+
+
+ # ---------------------------------------------------------------------------
+ # Per-sample observables
+ # ---------------------------------------------------------------------------
+
+ def energy_per_spin(grid: np.ndarray) -> float:
+     """Nearest-neighbour energy density with periodic boundary conditions.
+
+         E/N = −J/N · Σ_{⟨ij⟩} s_i s_j
+
+     Each bond is counted once via right- and down-shifts.
+     """
+     right = np.roll(grid, -1, axis=1)
+     down = np.roll(grid, -1, axis=0)
+     return float(-J * (grid * right + grid * down).sum() / grid.size)
+
+
+ def connected_correlations(
+     grid: np.ndarray,
+     distances: tuple[int, ...] = (1, 2, 4, 8),
+ ) -> np.ndarray:
+     """Isotropic connected two-point function C(r) = ½[<s_x s_{x+r}> + <s_y s_{y+r}>] − <s>².
+
+     Averaged over both spatial directions and all origin sites using
+     periodic boundary conditions.
+     """
+     m = float(grid.mean())
+     corr = []
+     for r in distances:
+         cx = float((grid * np.roll(grid, r, axis=1)).mean())
+         cy = float((grid * np.roll(grid, r, axis=0)).mean())
+         corr.append((cx + cy) / 2.0 - m ** 2)
+     return np.array(corr, dtype=np.float64)
+
+
+ def cluster_stats(grid: np.ndarray) -> tuple[float, float]:
+     """Mean and maximum cluster size for both spin species.
+
+     Uses 4-connectivity (no diagonals) and open boundary conditions.
+     Returns sizes normalised by the total number of spins so the result
+     is independent of lattice size.
+
+     Note: open BC means edge-spanning clusters are split at the boundary;
+     this is applied consistently to both real and generated samples so the
+     systematic bias cancels in the Mahalanobis comparison.
+     """
+     N = grid.size
+     all_sizes: list[np.ndarray] = []
+     for spin in (1, -1):
+         labeled, n_labels = scipy.ndimage.label(grid == spin)
+         if n_labels > 0:
+             # bincount index 0 is background; skip it
+             all_sizes.append(np.bincount(labeled.ravel())[1:])
+     if not all_sizes:
+         return 0.0, 0.0
+     sizes = np.concatenate(all_sizes).astype(np.float64)
+     return float(sizes.mean()) / N, float(sizes.max()) / N
+
+
+ def compute_features(grid: np.ndarray) -> np.ndarray:
+     """Return the 11-D feature vector for a single ±1 grid of shape (L, L)."""
+     m = float(grid.mean())
+     e = energy_per_spin(grid)
+     cr = connected_correlations(grid)
+     s_mean, s_max = cluster_stats(grid)
+     return np.array(
+         [m, m ** 2, abs(m), e, e ** 2, *cr, s_mean, s_max],
+         dtype=np.float64,
+     )
+
+
+ def compute_feature_matrix(grids: np.ndarray, desc: str = "features") -> np.ndarray:
+     """Compute the (N, 11) feature matrix for a batch of grids."""
+     return np.stack(
+         [compute_features(grids[i])
+          for i in tqdm(range(len(grids)), desc=desc, leave=False)]
+     )
+
+
+ # ---------------------------------------------------------------------------
+ # Ensemble statistics
+ # ---------------------------------------------------------------------------
+
+ def ensemble_stats(X: np.ndarray, T: float = T_C) -> dict[str, float]:
+     """Derive thermodynamic ensemble statistics from a feature matrix.
+
+     Arguments
+     ---------
+     X : (N, 11) feature matrix from ``compute_feature_matrix``.
+     T : temperature used for chi and C_v normalisation.
+     """
+     L = gen_config["lattice_size"]
+     N = L * L
+
+     m = X[:, FEATURE_NAMES.index("m")]
+     m2 = X[:, FEATURE_NAMES.index("m^2")]
+     m4 = m ** 4
+     e = X[:, FEATURE_NAMES.index("e")]
+
+     chi = N * float(m.var()) / T
+     Cv = N * float(e.var()) / T ** 2
+     binder = float(1.0 - m4.mean() / (3.0 * m2.mean() ** 2)) if m2.mean() > 0 else float("nan")
+
+     return {
+         "<|m|>": float(np.abs(m).mean()),
+         "chi": chi,
+         "C_v": Cv,
+         "U4": binder,
+     }
+
+
+ # ---------------------------------------------------------------------------
+ # Mahalanobis distance
+ # ---------------------------------------------------------------------------
+
+ def mahalanobis_distance(
+     X_ref: np.ndarray,
+     X_query: np.ndarray,
+     reg: float = 1e-6,
+ ) -> tuple[float, np.ndarray]:
+     """Mahalanobis distance of the query mean from the reference distribution.
+
+         D = sqrt( Δμ^T Σ_ref^{-1} Δμ )
+
+     Also returns per-feature z-scores z_i = Δμ_i / σ_ref_i,
+     where σ_ref_i = sqrt(Σ_ref[i,i]). |z_i| > 1 indicates a feature
+     whose mean differs by more than one real-sample standard deviation.
+
+     Parameters
+     ----------
+     X_ref   : (N, d) real / reference feature matrix
+     X_query : (M, d) generated / query feature matrix
+     reg     : diagonal regularisation added to Σ_ref before inversion
+     """
+     mu_ref = X_ref.mean(axis=0)
+     mu_query = X_query.mean(axis=0)
+     cov = np.cov(X_ref.T) + reg * np.eye(X_ref.shape[1])
+     cov_inv = np.linalg.inv(cov)
+     delta = mu_query - mu_ref
+     D = float(np.sqrt(max(0.0, delta @ cov_inv @ delta)))
+     z_scores = delta / np.sqrt(np.diag(cov))
+     return D, z_scores
+
+
+ # ---------------------------------------------------------------------------
+ # Reporting
+ # ---------------------------------------------------------------------------
+
+ def print_feature_table(X_real: np.ndarray, X_gen: np.ndarray) -> None:
+     mu_r = X_real.mean(axis=0)
+     sd_r = X_real.std(axis=0)
+     mu_g = X_gen.mean(axis=0)
+     sd_g = X_gen.std(axis=0)
+
+     col = 13
+     hdr = (f" {'Feature':<11} {'Real mean':>{col}} {'Real std':>{col}}"
+            f" {'Gen mean':>{col}} {'Gen std':>{col}} {'z-score':>8}")
+     print(hdr)
+     print(" " + "─" * (len(hdr) - 2))
+     for name, mr, sr, mg, sg in zip(FEATURE_NAMES, mu_r, sd_r, mu_g, sd_g):
+         z = (mg - mr) / (sr + 1e-12)
+         flag = " <" if abs(z) > 1.0 else ""
+         print(f" {name:<11} {mr:>{col}.4f} {sr:>{col}.4f}"
+               f" {mg:>{col}.4f} {sg:>{col}.4f} {z:>+8.3f}{flag}")
+     print()
+
+
+ def print_ensemble_table(stats_real: dict, stats_gen: dict) -> None:
+     labels = {
+         "<|m|>": "mean |m|",
+         "chi": "chi (susceptibility)",
+         "C_v": "C_v (specific heat)",
+         "U4": "U4 (Binder cumulant)",
+     }
+     print(f" {'Observable':<26} {'Real':>10} {'Generated':>10}")
+     print(" " + "─" * 50)
+     for key, label in labels.items():
+         r = stats_real[key]
+         g = stats_gen[key]
+         print(f" {label:<26} {r:>10.4f} {g:>10.4f}")
+     print()
+
+
+ # ---------------------------------------------------------------------------
+ # CLI
+ # ---------------------------------------------------------------------------
+
+ _SAMPLE_BATCH = 4  # fixed vmapped batch; changing it triggers recompilation
+
+
+ def generate_grids(model, n: int, key: jax.Array, L: int) -> np.ndarray:
+     """Sample n grids in batches of _SAMPLE_BATCH with a progress bar.
+
+     Using a fixed batch size means only one JIT compilation happens regardless
+     of n. The final partial batch is sampled in full, then trimmed.
+     """
+     batches = []
+     n_full, remainder = divmod(n, _SAMPLE_BATCH)
+     n_batches = n_full + (1 if remainder else 0)
+
+     with tqdm(total=n, unit="samples", desc="Sampling") as pbar:
+         for i in range(n_batches):
+             key, subkey = jax.random.split(key)
+             tokens = np.asarray(sample_batch(model, _SAMPLE_BATCH, subkey))
+             batches.append(tokens)
+             pbar.update(min(_SAMPLE_BATCH, n - i * _SAMPLE_BATCH))
+
+     return tokens_to_grids(np.concatenate(batches)[:n], L)
+
+
+ def load_test_grids(
+     test_data: Path | None,
+     data: Path,
+     n: int,
+     L: int,
+     rng: np.random.Generator,
+ ) -> np.ndarray:
+     """Load real test grids, preferring a dedicated test file over the val split.
+
+     Parameters
+     ----------
+     test_data : optional path to a standalone test .npy file (N, L, L) int8 {-1,+1}
+     data      : path to the main spins.npy (used only if test_data is None)
+     """
+     if test_data is not None:
+         spins = np.load(test_data)                  # (N, L, L) int8
+         tokens = (spins.astype(np.int32) + 1) // 2  # → {0, 1}
+         rows, cols = snake_order(L)
+         tokens = tokens[:, rows, cols]              # (N, L²)
+     else:
+         _, tokens = load_ising_data(data)           # val split of spins.npy
+
+     n = min(n, len(tokens))
+     idx = rng.choice(len(tokens), size=n, replace=False)
+     return tokens_to_grids(tokens[idx], L)          # (n, L, L), values ±1
+
+
+ def parse_args():
+     p = argparse.ArgumentParser(
+         description="Compare generated vs real Ising samples via physical observables."
+     )
+     p.add_argument("--checkpoint", type=Path, required=True,
+                    help="Path to the .eqx checkpoint file.")
+     p.add_argument("--data", type=Path,
+                    default=Path(__file__).parent / "spins.npy",
+                    help="Path to spins.npy (default: ./spins.npy). "
+                         "Used only if --test-data is not provided.")
+     p.add_argument("--test-data", type=Path,
+                    default=Path(__file__).parent / "spins_test.npy",
+                    help="Dedicated held-out test set (.npy, N×L×L int8 {-1,+1}). "
+                         "Takes priority over the val split of --data.")
+     p.add_argument("--num-samples", type=int, default=50,
+                    help="Number of samples to compare (default: 50).")
+     p.add_argument("--samples-file", type=Path, default=None,
+                    help="Optional .npy of pre-generated {-1,+1} grids (N, L, L) "
+                         "from 'sample.py --output'. Skips generation entirely.")
+     p.add_argument("--seed", type=int, default=0)
+     return p.parse_args()
+
+
+ def main():
+     args = parse_args()
+     L = gen_config["lattice_size"]
+     rng = np.random.default_rng(args.seed)
+
+     # ── Real samples (test split) ─────────────────────────────────────────
+     # Prefer spins_test.npy; fall back to the val split of spins.npy.
+     test_path = args.test_data if (args.test_data and args.test_data.exists()) else None
+     if test_path:
+         print(f"Loading test data from {test_path} …")
+     else:
+         print("Loading test data from val split of spins.npy …")
+     n = args.num_samples
+     real_grids = load_test_grids(test_path, args.data, n, L, rng)
+     n = len(real_grids)  # may be capped by dataset size
+
+     # ── Generated samples ─────────────────────────────────────────────────
+     if args.samples_file is not None:
+         print(f"Loading pre-generated samples from {args.samples_file} …")
+         gen_grids = np.load(args.samples_file).astype(np.int8)[:n]
+         if gen_grids.shape[1:] != (L, L):
+             raise ValueError(
+                 f"samples-file grid shape {gen_grids.shape[1:]} != ({L},{L})"
+             )
+         n = min(n, len(gen_grids))
+         real_grids = real_grids[:n]
+     else:
+         print(f"Loading checkpoint from {args.checkpoint} …")
+         model = load_checkpoint(args.checkpoint)
+         key = jax.random.PRNGKey(args.seed)
+         gen_grids = generate_grids(model, n, key, L)  # (n, L, L), values ±1
+
+     print(f"\nL = {L} | N = {n} samples per group | T_C = {T_C:.6f}\n")
+
+     # ── Feature matrices ──────────────────────────────────────────────────
+     X_real = compute_feature_matrix(real_grids, desc="Features: real ")
+     X_gen = compute_feature_matrix(gen_grids, desc="Features: generated ")
+
+     # ── Per-feature comparison table ──────────────────────────────────────
+     print("Per-feature statistics (z-score = Δμ / σ_real; '<' marks |z| > 1)\n")
+     print_feature_table(X_real, X_gen)
+
+     # ── Ensemble statistics ───────────────────────────────────────────────
+     print("Ensemble statistics\n")
+     print_ensemble_table(ensemble_stats(X_real), ensemble_stats(X_gen))
+
+     # ── Mahalanobis distance ──────────────────────────────────────────────
+     D, z = mahalanobis_distance(X_real, X_gen)
+     print(f"Mahalanobis distance D = {D:.4f}")
+     print("  (D measures how many 'std-devs' the generated feature mean sits")
+     print("  from the real distribution in the decorrelated feature space.)")
+     print()
+     print("  Top deviating features:")
+     order = np.argsort(np.abs(z))[::-1]
+     for i in order[:5]:
+         print(f"    {FEATURE_NAMES[i]:<11} z = {z[i]:+.3f}")
+
+
+ if __name__ == "__main__":
+     main()
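As a quick sanity check of `mahalanobis_distance`, here is a minimal sketch (not part of the commit, and assuming `eval.py` is importable from the working directory): two feature sets drawn from the same distribution should give a D near zero, up to finite-sample noise.

    import numpy as np
    from eval import mahalanobis_distance

    rng = np.random.default_rng(0)
    X_ref = rng.normal(size=(500, 11))    # stand-in "real" feature matrix
    X_query = rng.normal(size=(500, 11))  # stand-in "generated" feature matrix
    D, z = mahalanobis_distance(X_ref, X_query)
    print(f"D = {D:.3f}")  # small but nonzero: finite-sample noise, not model error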
examples.png ADDED
ising.py ADDED
@@ -0,0 +1,20 @@
+ """Legacy entry point. This file has been split into:
+     model.py  — Generator architecture and gen_config
+     train.py  — training loop (python train.py --help)
+     sample.py — checkpoint loading and autoregressive sampling (python sample.py --help)
+ """
+ # Re-export model symbols for any code that still does `from ising import ...`
+ from model import (  # noqa: F401
+     snake_order,
+     EmbedderBlock,
+     FeedForwardBlock,
+     AttentionBlock,
+     TransformerLayer,
+     Encoder,
+     Generator,
+     gen_config,
+ )
+
+ if __name__ == "__main__":
+     import train
+     train.main()
main.py ADDED
@@ -0,0 +1,26 @@
+ import matplotlib
+ matplotlib.use('Agg')
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import sys, jax
+ sys.path.insert(0, '.')
+ from sample import load_checkpoint, sample_batch, tokens_to_grids
+ model = load_checkpoint('checkpoint.eqx')
+ lattice_size = model.encoder.embedder_block.lattice_size
+ key = jax.random.PRNGKey(42)
+ print('Compiling and sampling 16 configurations...')
+ tokens = sample_batch(model, 16, key)
+ tokens = np.asarray(tokens)
+ grids = tokens_to_grids(tokens, lattice_size)
+ mags = grids.mean(axis=(1, 2))
+ print(f'Magnetizations: {np.round(mags, 3)}')
+ print(f'Mean |m|: {np.abs(mags).mean():.4f}')
+ fig, axes = plt.subplots(2, 8, figsize=(14, 4))
+ for i, ax in enumerate(axes.flat):
+     ax.imshow(grids[i], cmap='gray', vmin=-1, vmax=1, interpolation='nearest')
+     ax.set_title(f'm={mags[i]:.2f}', fontsize=7)
+     ax.axis('off')
+ fig.suptitle('Sampled Ising configs (L=32, T=T_c, 2 epochs)', fontsize=10)
+ plt.tight_layout()
+ plt.savefig('samples.png', dpi=150, bbox_inches='tight')
+ print('Saved samples.png')
metadata.json ADDED
@@ -0,0 +1,17 @@
+ {
+     "lattice_size": 32,
+     "sample_count": 10000,
+     "temperature": 2.269185314213022,
+     "temperature_note": "2D Ising critical temperature T_c = 2J/ln(1+sqrt(2))",
+     "coupling": 1.0,
+     "spin_values": [
+         -1,
+         1
+     ],
+     "burn_in_sweeps": 200,
+     "sample_interval_sweeps": 5,
+     "n_chains": 10,
+     "algorithm": "Wolff single-cluster",
+     "base_seed": 1778172101,
+     "generation_time_s": 4.9
+ }
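The data-generation script itself is not part of this commit. As context for the metadata above, here is a minimal sketch of one update of the Wolff single-cluster algorithm it names; `wolff_step` and its signature are illustrative, not code from this repo.

    import numpy as np

    def wolff_step(spins: np.ndarray, T: float, rng: np.random.Generator) -> None:
        """One Wolff cluster flip on an L×L periodic lattice of ±1 spins (J = 1)."""
        L = spins.shape[0]
        p_add = 1.0 - np.exp(-2.0 / T)  # bond-activation probability
        i0, j0 = int(rng.integers(L)), int(rng.integers(L))  # random seed site
        cluster_spin = spins[i0, j0]
        in_cluster = np.zeros_like(spins, dtype=bool)
        in_cluster[i0, j0] = True
        stack = [(i0, j0)]
        while stack:  # grow the cluster over aligned nearest neighbours
            i, j = stack.pop()
            for ni, nj in (((i + 1) % L, j), ((i - 1) % L, j),
                           (i, (j + 1) % L), (i, (j - 1) % L)):
                if (not in_cluster[ni, nj] and spins[ni, nj] == cluster_spin
                        and rng.random() < p_add):
                    in_cluster[ni, nj] = True
                    stack.append((ni, nj))
        spins[in_cluster] *= -1  # flip the whole cluster at once

Per the metadata, each chain would discard 200 burn-in sweeps of such updates before recording, then record a configuration every 5 sweeps, across 10 independent chains.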
model.py ADDED
@@ -0,0 +1,362 @@
+ """Transformer model for autoregressive Ising spin generation.
+
+ Architecture: causal (GPT-style) transformer with per-site positional
+ embeddings in snake (boustrophedon) order. The model is trained to maximise
+ p(s_0, s_1, ..., s_{N-1}) = ∏_t p(s_t | s_0, ..., s_{t-1}), where the spin
+ sites are visited in snake order over the L×L lattice.
+ """
+
+ from collections.abc import Mapping
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from jaxtyping import Array, Float, Int
+
+
+ def snake_order(size: int) -> tuple[np.ndarray, np.ndarray]:
+     """Return (rows, cols) index arrays traversing an L×L grid in snake order.
+
+     Even rows go left-to-right; odd rows go right-to-left. The returned
+     arrays have length size² and implement numpy advanced indexing:
+         grid[rows, cols]       → 1-D sequence in snake order
+         grid[rows, cols] = seq → scatter a sequence back to the grid
+     """
+     if size <= 0:
+         raise ValueError("size must be positive")
+     rows, cols = [], []
+     for row in range(size):
+         columns = range(size) if row % 2 == 0 else range(size - 1, -1, -1)
+         for col in columns:
+             rows.append(row)
+             cols.append(col)
+     return np.array(rows), np.array(cols)
+
+
+ # ---------------------------------------------------------------------------
+ # Building blocks
+ # ---------------------------------------------------------------------------
+
+ class EmbedderBlock(eqx.Module):
+     """Spin-state + lattice-position embedder.
+
+     Each position in the snake-order sequence gets three embeddings, summed:
+       • a learned spin-state embedding (token ∈ {0, 1})
+       • a learned row-position embedding
+       • a learned column-position embedding
+
+     The row/column indices are derived from `snake_order` at trace time, so
+     they fold to compile-time constants — no extra array parameters are stored.
+     """
+
+     state_embedder: eqx.nn.Embedding
+     row_embedder: eqx.nn.Embedding
+     column_embedder: eqx.nn.Embedding
+     layernorm: eqx.nn.LayerNorm
+     dropout: eqx.nn.Dropout
+     lattice_size: int = eqx.field(static=True)
+
+     def __init__(
+         self,
+         state_size: int,
+         lattice_size: int,
+         embedding_size: int,
+         hidden_size: int,
+         dropout_rate: float,
+         key: jax.Array,
+     ):
+         state_key, row_key, col_key = jax.random.split(key, 3)
+         self.state_embedder = eqx.nn.Embedding(
+             num_embeddings=state_size, embedding_size=embedding_size, key=state_key
+         )
+         self.row_embedder = eqx.nn.Embedding(
+             num_embeddings=lattice_size, embedding_size=embedding_size, key=row_key
+         )
+         self.column_embedder = eqx.nn.Embedding(
+             num_embeddings=lattice_size, embedding_size=embedding_size, key=col_key
+         )
+         self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
+         self.dropout = eqx.nn.Dropout(dropout_rate)
+         self.lattice_size = lattice_size
+
+     def __call__(
+         self,
+         states: Int[Array, " seq_len"],
+         enable_dropout: bool = False,
+         key: jax.Array | None = None,
+     ) -> Float[Array, "seq_len hidden_size"]:
+         rows, cols = snake_order(self.lattice_size)  # concrete at trace time
+         x_states = jax.vmap(self.state_embedder)(states)
+         x_rows = jax.vmap(self.row_embedder)(jnp.asarray(rows))
+         x_cols = jax.vmap(self.column_embedder)(jnp.asarray(cols))
+         x = x_states + x_rows + x_cols
+         x = jax.vmap(self.layernorm)(x)
+         x = self.dropout(x, inference=not enable_dropout, key=key)
+         return x
+
+
+ class FeedForwardBlock(eqx.Module):
+     """Position-wise feed-forward block with residual connection."""
+
+     mlp: eqx.nn.Linear
+     output: eqx.nn.Linear
+     layernorm: eqx.nn.LayerNorm
+     dropout: eqx.nn.Dropout
+
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         dropout_rate: float,
+         key: jax.Array,
+     ):
+         mlp_key, out_key = jax.random.split(key)
+         self.mlp = eqx.nn.Linear(hidden_size, intermediate_size, key=mlp_key)
+         self.output = eqx.nn.Linear(intermediate_size, hidden_size, key=out_key)
+         self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
+         self.dropout = eqx.nn.Dropout(dropout_rate)
+
+     def __call__(
+         self,
+         inputs: Float[Array, " hidden_size"],
+         enable_dropout: bool = False,
+         key: jax.Array | None = None,
+     ) -> Float[Array, " hidden_size"]:
+         x = jax.nn.gelu(self.mlp(inputs))
+         x = self.output(x)
+         x = self.dropout(x, inference=not enable_dropout, key=key)
+         x = x + inputs
+         x = self.layernorm(x)
+         return x
+
+
+ class AttentionBlock(eqx.Module):
+     """Multi-head self-attention with causal (lower-triangular) mask."""
+
+     attention: eqx.nn.MultiheadAttention
+     layernorm: eqx.nn.LayerNorm
+     dropout: eqx.nn.Dropout
+     num_heads: int = eqx.field(static=True)
+
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         dropout_rate: float,
+         attention_dropout_rate: float,
+         key: jax.Array,
+     ):
+         self.num_heads = num_heads
+         self.attention = eqx.nn.MultiheadAttention(
+             num_heads=num_heads,
+             query_size=hidden_size,
+             use_query_bias=True,
+             use_key_bias=True,
+             use_value_bias=True,
+             use_output_bias=True,
+             dropout_p=attention_dropout_rate,
+             key=key,
+         )
+         self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
+         self.dropout = eqx.nn.Dropout(dropout_rate)
+
+     def __call__(
+         self,
+         inputs: Float[Array, "seq_len hidden_size"],
+         mask: Int[Array, " seq_len"] | None,
+         enable_dropout: bool = False,
+         key: jax.Array | None = None,
+     ) -> Float[Array, "seq_len hidden_size"]:
+         attn_key, drop_key = (None, None) if key is None else jax.random.split(key)
+         if mask is not None:
+             mask = self._causal_mask(mask)
+         x = self.attention(
+             query=inputs, key_=inputs, value=inputs,
+             mask=mask, inference=not enable_dropout, key=attn_key,
+         )
+         x = self.dropout(x, inference=not enable_dropout, key=drop_key)
+         x = x + inputs
+         x = jax.vmap(self.layernorm)(x)
+         return x
+
+     def _causal_mask(
+         self, mask: Int[Array, " seq_len"]
+     ) -> Float[Array, "num_heads seq_len seq_len"]:
+         """Lower-triangular mask combined with a padding mask."""
+         n = mask.shape[0]
+         pad = jnp.multiply(mask[:, None], mask[None, :])        # [n, n]
+         causal = jnp.tril(jnp.ones((n, n), dtype=mask.dtype))   # [n, n]
+         m = jnp.multiply(pad, causal)                           # [n, n]
+         m = jnp.broadcast_to(m[None], (self.num_heads, n, n))   # [H, n, n]
+         return m.astype(jnp.float32)
+
+
+ class TransformerLayer(eqx.Module):
+     """One transformer block: attention followed by feed-forward."""
+
+     attention_block: AttentionBlock
+     ff_block: FeedForwardBlock
+
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         num_heads: int,
+         dropout_rate: float,
+         attention_dropout_rate: float,
+         key: jax.Array,
+     ):
+         attn_key, ff_key = jax.random.split(key)
+         self.attention_block = AttentionBlock(
+             hidden_size=hidden_size,
+             num_heads=num_heads,
+             dropout_rate=dropout_rate,
+             attention_dropout_rate=attention_dropout_rate,
+             key=attn_key,
+         )
+         self.ff_block = FeedForwardBlock(
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             dropout_rate=dropout_rate,
+             key=ff_key,
+         )
+
+     def __call__(
+         self,
+         inputs: Float[Array, "seq_len hidden_size"],
+         mask: Int[Array, " seq_len"] | None = None,
+         *,
+         enable_dropout: bool = False,
+         key: jax.Array | None = None,
+     ) -> Float[Array, "seq_len hidden_size"]:
+         attn_key, ff_key = (None, None) if key is None else jax.random.split(key)
+         x = self.attention_block(inputs, mask, enable_dropout=enable_dropout, key=attn_key)
+         n = x.shape[0]
+         ff_keys = None if ff_key is None else jax.random.split(ff_key, n)
+         x = jax.vmap(self.ff_block, in_axes=(0, None, 0))(x, enable_dropout, ff_keys)
+         return x
+
+
+ # ---------------------------------------------------------------------------
+ # Encoder and top-level Generator
+ # ---------------------------------------------------------------------------
+
+ class Encoder(eqx.Module):
+     """Stack of transformer layers over a snake-ordered spin sequence."""
+
+     embedder_block: EmbedderBlock
+     layers: list[TransformerLayer]
+
+     def __init__(
+         self,
+         state_size: int,
+         lattice_size: int,
+         embedding_size: int,
+         hidden_size: int,
+         intermediate_size: int,
+         num_layers: int,
+         num_heads: int,
+         dropout_rate: float,
+         attention_dropout_rate: float,
+         key: jax.Array,
+     ):
+         emb_key, layer_key = jax.random.split(key)
+         self.embedder_block = EmbedderBlock(
+             state_size=state_size,
+             lattice_size=lattice_size,
+             embedding_size=embedding_size,
+             hidden_size=hidden_size,
+             dropout_rate=dropout_rate,
+             key=emb_key,
+         )
+         layer_keys = jax.random.split(layer_key, num_layers)
+         self.layers = [
+             TransformerLayer(
+                 hidden_size=hidden_size,
+                 intermediate_size=intermediate_size,
+                 num_heads=num_heads,
+                 dropout_rate=dropout_rate,
+                 attention_dropout_rate=attention_dropout_rate,
+                 key=lk,
+             )
+             for lk in layer_keys
+         ]
+
+     def __call__(
+         self,
+         states: Int[Array, " seq_len"],
+         *,
+         enable_dropout: bool = False,
+         key: jax.Array | None = None,
+     ) -> Float[Array, "seq_len hidden_size"]:
+         emb_key, l_key = (None, None) if key is None else jax.random.split(key)
+         x = self.embedder_block(states, enable_dropout=enable_dropout, key=emb_key)
+         mask = jnp.ones_like(states, dtype=jnp.int32)  # no padding; causal only
+         for layer in self.layers:
+             cl_key, l_key = (None, None) if l_key is None else jax.random.split(l_key)
+             x = layer(x, mask, enable_dropout=enable_dropout, key=cl_key)
+         return x
+
+
+ class Generator(eqx.Module):
+     """Autoregressive transformer generator for Ising spin configurations.
+
+     Input:  token_ids — integer spin tokens {0=down, 1=up} in snake order.
+     Output: logits — shape (seq_len, state_size), where logits[t] is the
+             predicted distribution over the spin at position t+1
+             given positions 0..t.
+     """
+
+     encoder: Encoder
+     lm_head: eqx.nn.Linear
+     dropout: eqx.nn.Dropout
+
+     def __init__(self, config: Mapping, key: jax.Array):
+         enc_key, head_key = jax.random.split(key)
+         self.encoder = Encoder(
+             state_size=config["state_size"],
+             lattice_size=config["lattice_size"],
+             embedding_size=config["hidden_size"],
+             hidden_size=config["hidden_size"],
+             intermediate_size=config["intermediate_size"],
+             num_layers=config["num_hidden_layers"],
+             num_heads=config["num_attention_heads"],
+             dropout_rate=config["hidden_dropout_prob"],
+             attention_dropout_rate=config["attention_probs_dropout_prob"],
+             key=enc_key,
+         )
+         self.lm_head = eqx.nn.Linear(
+             in_features=config["hidden_size"],
+             out_features=config["state_size"],
+             key=head_key,
+         )
+         self.dropout = eqx.nn.Dropout(config["hidden_dropout_prob"])
+
+     def __call__(
+         self,
+         inputs: dict[str, Int[Array, " seq_len"]],
+         enable_dropout: bool = False,
+         key: jax.Array | None = None,
+     ) -> Float[Array, "seq_len state_size"]:
+         e_key, d_key = (None, None) if key is None else jax.random.split(key)
+         x = self.encoder(inputs["token_ids"], enable_dropout=enable_dropout, key=e_key)
+         x = self.dropout(x, inference=not enable_dropout, key=d_key)
+         return jax.vmap(self.lm_head)(x)
+
+
+ # ---------------------------------------------------------------------------
+ # Default configuration
+ # ---------------------------------------------------------------------------
+
+ gen_config = {
+     "state_size": 2,      # spin tokens: 0 (↓) or 1 (↑)
+     "lattice_size": 32,   # L×L lattice → L² = 1024 sequence length
+     "hidden_size": 128,
+     "num_hidden_layers": 2,
+     "num_attention_heads": 2,
+     "hidden_act": "gelu",
+     "intermediate_size": 512,
+     "hidden_dropout_prob": 0.1,
+     "attention_probs_dropout_prob": 0.1,
+ }
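As a minimal usage sketch of the classes above (assuming model.py is importable; nothing here is beyond what it defines), this instantiates the default Generator and runs one teacher-forced forward pass:

    import jax
    import jax.numpy as jnp
    from model import Generator, gen_config

    model = Generator(config=gen_config, key=jax.random.PRNGKey(0))
    seq_len = gen_config["lattice_size"] ** 2     # 32² = 1024 sites
    tokens = jnp.zeros(seq_len, dtype=jnp.int32)  # an all-"down" configuration
    logits = model({"token_ids": tokens})         # (1024, 2); logits[t] predicts s_{t+1}
    print(logits.shape)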
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ jax[cuda12]
+ equinox
+ optax
+ einops
+ tqdm
+ jaxtyping
+ gradio>=4.0
+ matplotlib
+ scipy
sample.py ADDED
@@ -0,0 +1,208 @@
+ #!/usr/bin/env python
+ # /// script
+ # dependencies = [
+ #     "jax[cuda12]",
+ #     "equinox",
+ #     "matplotlib",
+ #     "jaxtyping",
+ # ]
+ # ///
+ """Sample spin configurations from a trained Ising Generator checkpoint.
+
+ Usage:
+     python sample.py --checkpoint model.eqx [--num-samples 16]
+                      [--output samples.npy] [--plot] [--seed 0]
+
+ How autoregressive sampling works
+ ---------------------------------
+ The model is trained with a causal (lower-triangular) attention mask, so at
+ position t the output logits[t] are a function of spins s_0 … s_t only.
+ We exploit this to sample the full sequence one spin at a time:
+
+   1. Sample s_0 uniformly (the model has no BOS token).
+   2. For t = 0, 1, …, L²-2:
+      a. Run the full forward pass on the current token buffer.
+         Spins at positions > t are still placeholder zeros, but causal
+         masking prevents the network from attending to them.
+      b. Draw s_{t+1} ~ Categorical(softmax(logits[t])).
+      c. Write s_{t+1} into the buffer.
+
+ This is O(L⁶) in compute (L² sampling steps × O(L⁴) attention per step),
+ roughly 10⁹ operations for a 32×32 lattice. `jax.lax.scan` compiles the
+ loop body once so subsequent calls are fast.
+ """
+
+ import argparse
+ import functools
+ from pathlib import Path
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from jaxtyping import Array, Int
+
+ from model import Generator, gen_config, snake_order
+
+
+ # ---------------------------------------------------------------------------
+ # Checkpoint I/O
+ # ---------------------------------------------------------------------------
+
+ def load_checkpoint(
+     path: Path,
+     config: dict = gen_config,
+     key: jax.Array | None = None,
+ ) -> Generator:
+     """Deserialise a Generator from *path*.
+
+     A fresh model is initialised with *config* (weights are immediately
+     overwritten), so *key* only needs to be reproducible across calls if you
+     care about the random seed used for the skeleton — in practice any key
+     works.
+     """
+     if key is None:
+         key = jax.random.PRNGKey(0)
+     skeleton = Generator(config=config, key=key)
+     return eqx.tree_deserialise_leaves(path, skeleton)
+
+
+ # ---------------------------------------------------------------------------
+ # Sampling
+ # ---------------------------------------------------------------------------
+
+ def sample_sequence(
+     model: Generator,
+     key: jax.Array,
+ ) -> Int[Array, " seq_len"]:
+     """Autoregressively sample one spin sequence in snake order.
+
+     This function is JAX-traceable and safe to use inside ``jax.vmap`` or
+     ``jax.lax.scan``. JIT-compile it (or wrap in ``sample_batch``) for best
+     performance — the first call will take longer due to compilation.
+     """
+     lattice_size = model.encoder.embedder_block.lattice_size
+     seq_len = lattice_size * lattice_size
+
+     def step(carry, t):
+         tokens, step_key = carry
+         step_key, sample_key = jax.random.split(step_key)
+         logits = model({"token_ids": tokens}, enable_dropout=False, key=None)
+         # logits[t] → distribution over s_{t+1}
+         next_token = jax.random.categorical(sample_key, logits[t])
+         tokens = tokens.at[t + 1].set(next_token)
+         return (tokens, step_key), None
+
+     key, first_key = jax.random.split(key)
+     first_token = jax.random.randint(first_key, shape=(), minval=0, maxval=2)
+     tokens = jnp.zeros(seq_len, dtype=jnp.int32).at[0].set(first_token)
+     (tokens, _), _ = jax.lax.scan(step, (tokens, key), jnp.arange(seq_len - 1))
+     return tokens
+
+
+ @eqx.filter_jit
+ def sample_batch(
+     model: Generator,
+     num_samples: int,
+     key: jax.Array,
+ ) -> Int[Array, "num_samples seq_len"]:
+     """Sample *num_samples* configurations in parallel (vmapped + JIT'd).
+
+     The first call triggers compilation; subsequent calls with the same
+     ``num_samples`` reuse the compiled code.
+     """
+     keys = jax.random.split(key, num_samples)
+     return jax.vmap(sample_sequence, in_axes=(None, 0))(model, keys)
+
+
+ # ---------------------------------------------------------------------------
+ # Grid conversion
+ # ---------------------------------------------------------------------------
+
+ def tokens_to_grid(
+     tokens: np.ndarray | Array,
+     lattice_size: int,
+ ) -> np.ndarray:
+     """Convert a snake-ordered token sequence {0, 1} → a {-1, +1} L×L grid."""
+     tokens = np.asarray(tokens)
+     rows, cols = snake_order(lattice_size)
+     grid = np.empty((lattice_size, lattice_size), dtype=np.int8)
+     grid[rows, cols] = (tokens * 2 - 1).astype(np.int8)
+     return grid
+
+
+ def tokens_to_grids(
+     tokens: np.ndarray | Array,
+     lattice_size: int,
+ ) -> np.ndarray:
+     """Batch version of ``tokens_to_grid``. Input shape: (N, L²)."""
+     tokens = np.asarray(tokens)
+     return np.stack([tokens_to_grid(tokens[i], lattice_size) for i in range(len(tokens))])
+
+
+ # ---------------------------------------------------------------------------
+ # CLI
+ # ---------------------------------------------------------------------------
+
+ def parse_args():
+     p = argparse.ArgumentParser(description="Sample from a trained Ising Generator.")
+     p.add_argument("--checkpoint", type=Path, required=True,
+                    help="Path to the .eqx checkpoint file.")
+     p.add_argument("--num-samples", type=int, default=16)
+     p.add_argument("--output", type=Path, default=None,
+                    help="Save sampled {-1,+1} grids as a .npy file (N, L, L).")
+     p.add_argument("--plot", action="store_true",
+                    help="Display a grid of sampled configurations with matplotlib.")
+     p.add_argument("--seed", type=int, default=0)
+     return p.parse_args()
+
+
+ def main():
+     args = parse_args()
+
+     print(f"Loading checkpoint from {args.checkpoint} …")
+     model = load_checkpoint(args.checkpoint)
+     lattice_size = model.encoder.embedder_block.lattice_size
+     print(f"  lattice_size={lattice_size}, seq_len={lattice_size**2}")
+
+     key = jax.random.PRNGKey(args.seed)
+     print(f"Sampling {args.num_samples} configurations "
+           f"(compiling on first call) …")
+     tokens = sample_batch(model, args.num_samples, key)
+     tokens = np.asarray(tokens)
+
+     grids = tokens_to_grids(tokens, lattice_size)  # (N, L, L), values {-1, +1}
+     print(f"  shape: {grids.shape}  dtype: {grids.dtype}")
+     print(f"  mean magnetization  : {grids.mean():.4f}")
+     print(f"  mean |magnetization|: {np.abs(grids.mean(axis=(1, 2))).mean():.4f}")
+
+     if args.output is not None:
+         np.save(args.output, grids)
+         print(f"Saved → {args.output}")
+
+     if args.plot:
+         try:
+             import matplotlib.pyplot as plt
+         except ImportError:
+             print("matplotlib not available; skipping plot (pip install matplotlib).")
+             return
+
+         cols = min(8, args.num_samples)
+         rows = (args.num_samples + cols - 1) // cols
+         fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5))
+         axes = np.array(axes).reshape(-1)
+         for i, ax in enumerate(axes):
+             if i < len(grids):
+                 ax.imshow(grids[i], cmap="gray", vmin=-1, vmax=1, interpolation="nearest")
+             ax.axis("off")
+         fig.suptitle(
+             f"Sampled Ising configurations "
+             f"(L={lattice_size}, n={args.num_samples})",
+             fontsize=10,
+         )
+         plt.tight_layout()
+         plt.show()
+
+
+ if __name__ == "__main__":
+     main()
samples-2-epoch.png ADDED
spins.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c1890a0f2baeee4212b3bd79b74bff1f1f31135d9031ffb8c79bc558eae23a83
+ size 10240128
spins_test.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:040dea2a4f3b5753c7531d7a55643a7be3c8be4a7d9b9c1f11f481ea522ff67e
+ size 1024128
train.py ADDED
@@ -0,0 +1,229 @@
1
+ #!/usr/bin/env python
2
+ # /// script
3
+ # dependencies = [
4
+ # "jax[cuda12]",
5
+ # "equinox",
6
+ # "optax",
7
+ # "einops",
8
+ # "tqdm",
9
+ # "jaxtyping",
10
+ # ]
11
+ # ///
12
+ """Training script for the Ising spin Generator.
13
+
14
+ Usage:
15
+ python train.py [--epochs N] [--batch-size B] [--learning-rate LR]
16
+ [--data path/to/spins.npy] [--output-checkpoint model.eqx]
17
+ """
18
+
19
+ import argparse
20
+ import functools
21
+ from pathlib import Path
22
+
23
+ import einops
24
+ import equinox as eqx
25
+ import jax
26
+ import jax.numpy as jnp
27
+ import numpy as np
28
+ import optax
29
+ from tqdm.auto import tqdm
30
+
31
+ from model import Generator, gen_config, snake_order
32
+
33
+
34
+ # ---------------------------------------------------------------------------
35
+ # Data loading
36
+ # ---------------------------------------------------------------------------
37
+
38
+ def load_ising_data(
39
+ path: Path, train_frac: float = 0.9
40
+ ) -> tuple[np.ndarray, np.ndarray]:
41
+ """Load spins.npy, map {-1,1} β†’ {0,1}, flatten with snake ordering.
42
+
43
+ Returns ``(train_tokens, val_tokens)``, each ``(N, LΒ²)`` int32.
44
+ """
45
+ spins = np.load(path) # (N, L, L) int8
46
+ lattice_size = spins.shape[1]
47
+ tokens = (spins.astype(np.int32) + 1) // 2 # (N, L, L), values in {0,1}
48
+ rows, cols = snake_order(lattice_size)
49
+ tokens = tokens[:, rows, cols] # (N, LΒ²)
50
+ n_train = int(len(tokens) * train_frac)
51
+ return tokens[:n_train], tokens[n_train:]
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # Batch preparation
56
+ # ---------------------------------------------------------------------------
57
+
58
+ def prepare_batch(batch: np.ndarray, num_devices: int) -> dict:
59
+ """Reshape ``(batch, seq)`` β†’ ``(devices, batch//devices, seq)`` for pmap."""
60
+ token_ids = einops.rearrange(
61
+ batch,
62
+ "(devices batch) seq -> devices batch seq",
63
+ devices=num_devices,
64
+ )
65
+ return {"token_ids": token_ids}
66
+
67
+
68
+ # ---------------------------------------------------------------------------
69
+ # Training / eval steps
70
+ # ---------------------------------------------------------------------------
71
+
72
+ @eqx.filter_value_and_grad
73
+ def compute_loss(model, inputs, key):
74
+ """Autoregressive cross-entropy: logits[:, :-1] predicts token_ids[:, 1:]."""
75
+ batch_size = inputs["token_ids"].shape[0]
76
+ keys = jax.random.split(key, batch_size)
77
+ logits = jax.vmap(model, in_axes=(0, None, 0))(inputs, True, keys)
78
+ return jnp.mean(
79
+ optax.softmax_cross_entropy_with_integer_labels(
80
+ logits=logits[:, :-1, :],
81
+ labels=inputs["token_ids"][:, 1:],
82
+ )
83
+ )
84
+
85
+
86
+ def make_step(model, inputs, opt_state, key, tx):
87
+ key, new_key = jax.random.split(key)
88
+ loss, grads = compute_loss(model, inputs, key)
89
+ grads = jax.lax.pmean(grads, axis_name="devices")
90
+ updates, opt_state = tx.update(grads, opt_state, model)
91
+ model = eqx.apply_updates(model, updates)
92
+ return loss, model, opt_state, new_key
93
+
94
+
95
+ def make_eval_step(model, inputs):
96
+ """Per-device mean NLL (nats/token), called inside pmap."""
97
+ logits = jax.vmap(functools.partial(model, enable_dropout=False))(inputs)
98
+ return jnp.mean(
99
+ optax.softmax_cross_entropy_with_integer_labels(
100
+ logits=logits[:, :-1, :],
101
+ labels=inputs["token_ids"][:, 1:],
102
+ )
103
+ )
104
+
105
+
106
+ p_make_eval_step = eqx.filter_pmap(make_eval_step)


# ---------------------------------------------------------------------------
# pmap helpers
# ---------------------------------------------------------------------------

def replicate_for_pmap(value, devices):
    mesh = jax.sharding.Mesh(np.asarray(devices), ("devices",))
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("devices"))

    def replicate_leaf(leaf):
        leaf = jnp.asarray(leaf)
        leaf = jnp.broadcast_to(leaf, (len(devices),) + leaf.shape)
        return jax.device_put(leaf, sharding)

    return jax.tree.map(replicate_leaf, value)


def unreplicate_from_pmap(value):
    return jax.tree.map(lambda leaf: leaf[0], value)
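
# Round-trip sketch (illustrative; the leaf shape is hypothetical):
# replicate_for_pmap prepends a device axis to every leaf, e.g. a (d_model,)
# bias becomes (num_devices, d_model) holding identical copies, and
# unreplicate_from_pmap takes copy 0 back off, so
# unreplicate_from_pmap(replicate_for_pmap(x, devices)) recovers x
# leaf-for-leaf.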


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args():
    p = argparse.ArgumentParser(description="Train an Ising spin generator.")
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--batch-size", type=int, default=32)
    p.add_argument("--learning-rate", type=float, default=1e-4)
    p.add_argument("--max-train-steps", type=int, default=None)
    p.add_argument("--max-eval-batches", type=int, default=None)
    p.add_argument("--data", type=Path,
                   default=Path(__file__).parent / "spins.npy")
    p.add_argument("--output-checkpoint", type=Path, default=None)
    p.add_argument("--seed", type=int, default=5678)
    return p.parse_args()
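
# Example invocation (a sketch; the checkpoint path is hypothetical):
#   python train.py --epochs 2 --batch-size 32 \
#       --output-checkpoint outputs/ce_checkpoint.eqx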


def main():
    args = parse_args()
    num_devices = jax.device_count()
    print(f"JAX devices: {jax.devices()}")
    assert args.batch_size % num_devices == 0, (
        "batch-size must be a multiple of the number of devices"
    )

    key = jax.random.PRNGKey(args.seed)
    model_key, train_key = jax.random.split(key)
    model = Generator(config=gen_config, key=model_key)

    train_tokens, val_tokens = load_ising_data(args.data)
    print(
        f"Train: {len(train_tokens):,}  Val: {len(val_tokens):,}  "
        f"Seq len: {train_tokens.shape[1]}"
    )

    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=args.learning_rate),
    )
    # Mask to float-only leaves so integer bookkeeping fields are excluded.
    tx = optax.masked(tx, jax.tree.map(eqx.is_inexact_array, model))
    opt_state = tx.init(model)

    p_make_step = eqx.filter_pmap(
        functools.partial(make_step, tx=tx), axis_name="devices"
    )

    devices = jax.local_devices()
    opt_state = replicate_for_pmap(opt_state, devices)
    model = replicate_for_pmap(model, devices)
    train_key = replicate_for_pmap(train_key, devices)

    global_step = 0
    for epoch in range(args.epochs):
        rng = np.random.default_rng(args.seed + epoch)
        shuffled = train_tokens[rng.permutation(len(train_tokens))]

        num_batches = len(shuffled) // args.batch_size
        if args.max_train_steps is not None:
            num_batches = min(num_batches, max(args.max_train_steps - global_step, 0))

        with tqdm(range(num_batches), unit="steps",
                  desc=f"Epoch {epoch + 1}/{args.epochs}") as pbar:
            for step in pbar:
                batch = shuffled[step * args.batch_size : (step + 1) * args.batch_size]
                inputs = prepare_batch(batch, num_devices)
                loss, model, opt_state, train_key = p_make_step(
                    model, inputs, opt_state, train_key
                )
                global_step += 1
                # loss holds one mean per device; average (rather than sum)
                # so the reported value is comparable across device counts.
                pbar.set_postfix(loss=float(np.mean(loss)))
                if args.max_train_steps and global_step >= args.max_train_steps:
                    break

        if args.max_train_steps and global_step >= args.max_train_steps:
            break

    # ---- validation ----
    num_val = len(val_tokens) // args.batch_size
    if args.max_eval_batches is not None:
        num_val = min(num_val, args.max_eval_batches)

    val_losses = []
    for step in tqdm(range(num_val), unit="steps", desc="Validation"):
        batch = val_tokens[step * args.batch_size : (step + 1) * args.batch_size]
        inputs = prepare_batch(batch, num_devices)
        val_losses.append(float(np.mean(p_make_eval_step(model, inputs))))

    print(f"Val NLL: {np.mean(val_losses):.4f} nats/token")

    if args.output_checkpoint is not None:
        args.output_checkpoint.parent.mkdir(parents=True, exist_ok=True)
        eqx.tree_serialise_leaves(
            args.output_checkpoint, unreplicate_from_pmap(model)
        )
        print(f"Saved checkpoint → {args.output_checkpoint}")


if __name__ == "__main__":
    main()
vi_train.py ADDED
@@ -0,0 +1,288 @@
#!/usr/bin/env python
# /// script
# dependencies = [
#     "jax[cuda12]",
#     "equinox",
#     "optax",
#     "tqdm",
#     "jaxtyping",
# ]
# ///
"""Variational inference fine-tuning of the Ising Generator.

Objective
---------
Minimise the variational free energy

    F[q] = E_{s~q}[E(s)] − T · H[q]
         = T · E_{s~q}[ E(s)/T + log q(s) ]

which equals T · KL(q ∥ p*) up to the additive constant −T·log Z, where
p*(s) ∝ exp(−E(s)/T) is the Ising Boltzmann distribution.
As F decreases, the model q approaches the correct physics.

Gradient estimator (REINFORCE / score-function)
-----------------------------------------------
    ∇_θ F = T · E_{s~q}[( E(s)/T + log q(s) − b ) · ∇_θ log q(s)]

where b, the batch-mean reward, is a baseline (control variate) that
reduces the variance of the estimator without biasing it.

Per training step
-----------------
1. Sample a batch of configs from the current model q_θ (slow on CPU)
2. Compute Ising energy E(s) for each sample (fast, no model)
3. Compute log q(s) via a teacher-forced forward pass (fast, one pass)
4. Assemble reward R = E/T + log q, subtract baseline, backprop (fast)

Speed note
----------
Step 1 dominates on CPU (~5 s / sample on a 32×32 lattice). On GPU it is
typically 10–100× faster, and VI training with batch_size ≥ 32 is practical.
The --checkpoint flag warm-starts from a CE-pretrained model, which
dramatically reduces the number of VI steps needed to converge.

Monitoring
----------
At each step the following quantities are logged:

    e   = ⟨E/N⟩        mean energy per spin (converges to data ≈ −1.45)
    h   = −⟨log q⟩/N   entropy per spin in nats (random init ≈ 0.693)
    f   = e − T·h      Helmholtz free energy per spin (we minimise this)
    |m| = ⟨|Σs / N|⟩   mean absolute magnetisation
"""

import argparse
from pathlib import Path

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
from tqdm.auto import tqdm

from model import Generator, gen_config, snake_order
from sample import load_checkpoint, sample_batch


# ---------------------------------------------------------------------------
# Physical constants
# ---------------------------------------------------------------------------

J = 1.0
T_C = 2.0 / np.log(1.0 + np.sqrt(2.0))  # exact 2D Ising T_c ≈ 2.2692


# ---------------------------------------------------------------------------
# Ising energy (pure JAX: no model parameters, no gradient needed)
# ---------------------------------------------------------------------------

def ising_energy_per_spin(token_ids: jax.Array) -> jax.Array:
    """Ising energy per spin from a snake-ordered token sequence {0, 1}.

    Hamiltonian: H = −J Σ_{⟨ij⟩} s_i s_j with periodic boundary conditions.
    Returns a scalar.
    """
    L = gen_config["lattice_size"]
    rows, cols = snake_order(L)  # concrete at trace time
    spins = (token_ids * 2 - 1).astype(jnp.float32)  # {0,1} → {−1,+1}
    grid = jnp.zeros((L, L)).at[jnp.asarray(rows), jnp.asarray(cols)].set(spins)
    right = jnp.roll(grid, -1, axis=1)
    down = jnp.roll(grid, -1, axis=0)
    return -J * (grid * right + grid * down).sum() / (L * L)
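
# Sanity check (a sketch, not executed by the script): in the all-up
# configuration every site contributes −J from its right bond and −J from
# its down bond, so the energy per spin is exactly −2J:
#
#     all_up = jnp.ones(gen_config["lattice_size"] ** 2, dtype=jnp.int32)
#     assert float(ising_energy_per_spin(all_up)) == -2.0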


# ---------------------------------------------------------------------------
# VI loss (REINFORCE with per-batch baseline)
# ---------------------------------------------------------------------------

@eqx.filter_value_and_grad(has_aux=True)
def compute_vi_loss(
    model,
    token_ids: jax.Array,
    T: float,
) -> tuple[jax.Array, dict]:
    """REINFORCE proxy loss for ∇_θ F[q].

    The returned scalar is *not* the free energy; it is the REINFORCE
    surrogate whose gradient equals ∇_θ F/T. Use aux["f"] to track F.

    Parameters
    ----------
    model     : Generator
    token_ids : int array (batch, seq_len) samples drawn from q_θ
    T         : target temperature

    Returns
    -------
    (loss, aux), grads
        aux keys: e, h, f (per spin), |m|, reward_std (variance diagnostic)
    """
    N = gen_config["lattice_size"] ** 2

    # ── log q(s) via teacher-forced forward pass ─────────────────────────────
    # Disable dropout so we get the exact model log-probability.
    # in_axes=(0, None, None): vmap over batch; broadcast enable_dropout and key.
    logits = jax.vmap(model, in_axes=(0, None, None))(
        {"token_ids": token_ids}, False, None
    )  # (batch, seq_len, state_size)

    # Σ_t log p(s_t | s_{<t}), summed over the sequence axis
    log_q = -optax.softmax_cross_entropy_with_integer_labels(
        logits[:, :-1, :], token_ids[:, 1:]
    ).sum(axis=-1)  # (batch,)

    # ── Ising energies ────────────────────────────────────────────────────────
    # jax.lax.stop_gradient keeps energies out of the autodiff graph.
    energies = jax.lax.stop_gradient(
        jax.vmap(ising_energy_per_spin)(token_ids)  # (batch,)
    )

    # ── REINFORCE reward R = E/T + log q(s) ──────────────────────────────────
    reward = jax.lax.stop_gradient(energies / T + log_q)
    baseline = reward.mean()

    # Surrogate loss: ∇ loss = E_q[(R − b) · ∇ log q] = ∇ F/T, so gradient
    # descent on this scalar decreases the free energy. (Note the sign:
    # above-average-reward samples must have their log q pushed down.)
    loss = jnp.mean(jax.lax.stop_gradient(reward - baseline) * log_q)

    # ── Diagnostics (returned as aux; they do not enter the gradient) ────────
    e = energies.mean()  # mean energy per spin
    h = -log_q.mean() / N  # entropy per spin (nats)
    f = e - T * h  # Helmholtz free energy per spin
    m = jnp.abs((token_ids * 2 - 1).astype(jnp.float32).mean(axis=-1)).mean()
    aux = {
        "e": e,
        "h": h,
        "f": f,
        "|m|": m,
        "reward_std": reward.std(),  # REINFORCE variance diagnostic
    }
    return loss, aux
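
# Sign intuition (a sketch): with two samples whose centred rewards are
# R − b = [+1, −1], the surrogate is (log q(s1) − log q(s2)) / 2; gradient
# descent lowers log q of the high-reward (high free-energy) sample and
# raises log q of the low-reward one, pushing q toward the Boltzmann
# distribution.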


# ---------------------------------------------------------------------------
# Single training step (JIT-compiled; does NOT include sampling)
# ---------------------------------------------------------------------------

@eqx.filter_jit
def vi_step(
    model,
    token_ids: jax.Array,
    opt_state,
    tx,
    T: float,
):
    """Compute VI loss + gradient and apply one optimiser update.

    Sampling is intentionally excluded so you can profile / replace it
    without re-compiling the gradient computation.
    """
    (loss, aux), grads = compute_vi_loss(model, token_ids, T)
    updates, opt_state = tx.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return loss, aux, model, opt_state


# ---------------------------------------------------------------------------
# Sampling helper
# ---------------------------------------------------------------------------

_SAMPLE_BATCH = 4  # fixed call-site batch; changing it triggers recompilation


def draw_samples(model, n: int, key: jax.Array) -> jax.Array:
    """Sample n configurations from the model in fixed-size batches.

    Returns a jnp int32 array of shape (n, L²).
    """
    all_tokens = []
    n_calls = -(-n // _SAMPLE_BATCH)  # ceiling division
    for _ in range(n_calls):
        key, subkey = jax.random.split(key)
        all_tokens.append(np.asarray(sample_batch(model, _SAMPLE_BATCH, subkey)))
    return jnp.asarray(np.concatenate(all_tokens)[:n])
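
# Example (illustrative): draw_samples(model, 10, key) makes
# -(-10 // _SAMPLE_BATCH) == 3 sampling calls of 4 configurations each,
# then truncates the concatenated (12, L²) array back to 10 rows.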


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args():
    p = argparse.ArgumentParser(
        description="Variational inference fine-tuning of an Ising Generator.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--checkpoint", type=Path, default=None,
                   help="Warm-start from this .eqx file (strongly recommended).")
    p.add_argument("--output-checkpoint", type=Path, default=None,
                   help="Save final model to this path.")
    p.add_argument("--num-steps", type=int, default=200)
    p.add_argument("--batch-size", type=int, default=16,
                   help="Configurations sampled from q per gradient step.")
    p.add_argument("--learning-rate", type=float, default=1e-4)
    p.add_argument("--temperature", type=float, default=T_C,
                   help=f"Target Boltzmann temperature (default: T_c ≈ {T_C:.4f}).")
    p.add_argument("--log-every", type=int, default=1)
    p.add_argument("--seed", type=int, default=0)
    return p.parse_args()
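
# Example invocation (a sketch; paths are hypothetical):
#   python vi_train.py --checkpoint outputs/ce_checkpoint.eqx \
#       --num-steps 200 --batch-size 16 \
#       --output-checkpoint outputs/vi_checkpoint.eqx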


def main():
    args = parse_args()
    T = args.temperature
    key = jax.random.PRNGKey(args.seed)

    # ── Model ─────────────────────────────────────────────────────────────────
    if args.checkpoint is not None:
        print(f"Loading checkpoint from {args.checkpoint} …")
        model = load_checkpoint(args.checkpoint)
    else:
        print("Initialising model from scratch.")
        print("  Tip: use --checkpoint to warm-start from a CE-pretrained model.")
        key, model_key = jax.random.split(key)
        model = Generator(config=gen_config, key=model_key)

    # ── Optimiser ─────────────────────────────────────────────────────────────
    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=args.learning_rate),
    )
    tx = optax.masked(tx, jax.tree.map(eqx.is_inexact_array, model))
    opt_state = tx.init(model)

    L = gen_config["lattice_size"]
    print(f"\nVI training | steps={args.num_steps} "
          f"batch={args.batch_size} T={T:.4f} lr={args.learning_rate} L={L}")
    print("  columns: e = ⟨E/N⟩   h = −⟨log q⟩/N   "
          "f = e−T·h (minimised)   |m| = mean |magnetisation|\n")

    # ── Training loop ─────────────────────────────────────────────────────────
    with tqdm(range(args.num_steps), unit="steps") as pbar:
        for step in pbar:
            # 1. Sample from current model (bottleneck on CPU)
            key, sample_key = jax.random.split(key)
            token_ids = draw_samples(model, args.batch_size, sample_key)

            # 2. VI gradient step (JIT-compiled teacher-forced forward pass)
            loss, aux, model, opt_state = vi_step(
                model, token_ids, opt_state, tx, T
            )

            if step % args.log_every == 0:
                pbar.set_postfix(
                    e=f"{float(aux['e']):.4f}",
                    h=f"{float(aux['h']):.4f}",
                    f=f"{float(aux['f']):.4f}",
                    m=f"{float(aux['|m|']):.3f}",
                    Rstd=f"{float(aux['reward_std']):.3f}",
                )

    # ── Save ──────────────────────────────────────────────────────────────────
    if args.output_checkpoint is not None:
        args.output_checkpoint.parent.mkdir(parents=True, exist_ok=True)
        eqx.tree_serialise_leaves(args.output_checkpoint, model)
        print(f"\nSaved checkpoint → {args.output_checkpoint}")


if __name__ == "__main__":
    main()