Initial upload: model, training scripts, Gradio app, data
- app.py +287 -0
- eval.py +386 -0
- examples.png +0 -0
- ising.py +20 -0
- main.py +26 -0
- metadata.json +17 -0
- model.py +362 -0
- requirements.txt +9 -0
- sample.py +208 -0
- samples-2-epoch.png +0 -0
- spins.npy +3 -0
- spins_test.npy +3 -0
- train.py +229 -0
- vi_train.py +288 -0
app.py
ADDED
@@ -0,0 +1,287 @@
import shlex
import subprocess
import sys
from pathlib import Path

import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------

OUT = Path("outputs")
CE_CKPT = OUT / "ce_checkpoint.eqx"
VI_CKPT = OUT / "vi_checkpoint.eqx"
CE_SMPL = OUT / "samples_ce.npy"
VI_SMPL = OUT / "samples_vi.npy"
TRAIN_DATA = Path("spins.npy")
TEST_DATA = Path("spins_test.npy")


# ---------------------------------------------------------------------------
# Subprocess helpers
# ---------------------------------------------------------------------------

def _stream(command: list[str]):
    """Run a command and yield log lines in real time."""
    log = ["$ " + " ".join(shlex.quote(p) for p in command), ""]
    yield "\n".join(log)
    proc = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
    )
    assert proc.stdout is not None
    for line in proc.stdout:
        log.append(line.rstrip())
        yield "\n".join(log[-300:])
    rc = proc.wait()
    log += ["", f"─ exited {rc} ─"]
    yield "\n".join(log[-300:])


# ---------------------------------------------------------------------------
# Sample grid figure
# ---------------------------------------------------------------------------

def _samples_figure(path: Path, title: str, n: int = 16) -> plt.Figure | None:
    if not path.exists():
        return None
    grids = np.load(path).astype(np.float32)[:n]  # (N, L, L), values ±1
    cols = min(8, len(grids))
    rows = (len(grids) + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.4, rows * 1.4))
    axes = np.array(axes).reshape(-1)
    mags = grids.mean(axis=(1, 2))
    for i, ax in enumerate(axes):
        if i < len(grids):
            ax.imshow(grids[i], cmap="gray", vmin=-1, vmax=1, interpolation="nearest")
            ax.set_title(f"m={mags[i]:.2f}", fontsize=6)
        ax.axis("off")
    fig.suptitle(title, fontsize=9)
    plt.tight_layout()
    return fig


# ---------------------------------------------------------------------------
# Tab 1 – Cross-entropy training
# ---------------------------------------------------------------------------

def run_ce(mode, epochs, batch_size, lr, max_steps):
    OUT.mkdir(parents=True, exist_ok=True)
    cmd = [
        sys.executable, "train.py",
        "--data", str(TRAIN_DATA),
        "--batch-size", str(int(batch_size)),
        "--learning-rate", str(lr),
        "--output-checkpoint", str(CE_CKPT),
    ]
    if mode == "Smoke":
        cmd += ["--epochs", "1", "--max-train-steps", "5", "--max-eval-batches", "2"]
    else:
        cmd += ["--epochs", str(int(epochs))]
        if int(max_steps) > 0:
            cmd += ["--max-train-steps", str(int(max_steps))]
    for log in _stream(cmd):
        yield log, None
    ckpt = str(CE_CKPT) if CE_CKPT.exists() else None
    yield log, ckpt


# ---------------------------------------------------------------------------
# Tab 2 – Variational inference fine-tuning
# ---------------------------------------------------------------------------

def run_vi(mode, steps, batch_size, lr, warm_start):
    OUT.mkdir(parents=True, exist_ok=True)
    if warm_start and not CE_CKPT.exists():
        yield "⚠ CE checkpoint not found. Run CE training first, or uncheck warm-start.", None
        return
    cmd = [
        sys.executable, "vi_train.py",
        "--batch-size", str(int(batch_size)),
        "--learning-rate", str(lr),
        "--output-checkpoint", str(VI_CKPT),
    ]
    if warm_start and CE_CKPT.exists():
        cmd += ["--checkpoint", str(CE_CKPT)]
    if mode == "Smoke":
        cmd += ["--num-steps", "3", "--log-every", "1"]
    else:
        cmd += ["--num-steps", str(int(steps))]
    for log in _stream(cmd):
        yield log, None
    ckpt = str(VI_CKPT) if VI_CKPT.exists() else None
    yield log, ckpt


# ---------------------------------------------------------------------------
# Tab 3 – Sample & Eval
# ---------------------------------------------------------------------------

def run_eval(which, num_samples, seed):
    OUT.mkdir(parents=True, exist_ok=True)
    log_lines = []

    def emit(msg=""):
        log_lines.append(msg)
        return (
            "\n".join(log_lines[-200:]),
            None, None,  # CE figure, VI figure
        )

    run_ce_ = which in ("CE", "Both")
    run_vi_ = which in ("VI", "Both")

    if run_ce_ and not CE_CKPT.exists():
        yield emit("⚠ CE checkpoint not found. Run CE training first.")
        return
    if run_vi_ and not VI_CKPT.exists():
        yield emit("⚠ VI checkpoint not found. Run VI training first.")
        return

    # ── Generate samples ───────────────────────────────────────────────────
    for ckpt, out_path, label in [
        (CE_CKPT, CE_SMPL, "CE"),
        (VI_CKPT, VI_SMPL, "VI"),
    ]:
        if (label == "CE" and not run_ce_) or (label == "VI" and not run_vi_):
            continue
        log_lines.append(f"\n── Generating {num_samples} {label} samples ──")
        yield "\n".join(log_lines[-200:]), None, None
        cmd = [
            sys.executable, "sample.py",
            "--checkpoint", str(ckpt),
            "--num-samples", str(int(num_samples)),
            "--output", str(out_path),
            "--seed", str(int(seed)),
        ]
        for chunk in _stream(cmd):
            log_lines[-1:] = chunk.splitlines()[-10:]
            yield "\n".join(log_lines[-200:]), None, None

    # ── Run eval ────────────────────────────────────────────────────────────
    for ckpt, smpl, label in [
        (CE_CKPT, CE_SMPL, "CE"),
        (VI_CKPT, VI_SMPL, "VI"),
    ]:
        if (label == "CE" and not run_ce_) or (label == "VI" and not run_vi_):
            continue
        log_lines.append(f"\n── Evaluating {label} model ──")
        yield "\n".join(log_lines[-200:]), None, None
        cmd = [
            sys.executable, "eval.py",
            "--checkpoint", str(ckpt),
            "--test-data", str(TEST_DATA),
            "--num-samples", str(int(num_samples)),
            "--samples-file", str(smpl),
            "--seed", str(int(seed)),
        ]
        for chunk in _stream(cmd):
            log_lines[-1:] = chunk.splitlines()[-20:]
            yield "\n".join(log_lines[-200:]), None, None

    # ── Build figures ────────────────────────────────────────────────────────
    ce_fig = _samples_figure(CE_SMPL, "CE samples") if run_ce_ else None
    vi_fig = _samples_figure(VI_SMPL, "VI samples") if run_vi_ else None
    yield "\n".join(log_lines[-200:]), ce_fig, vi_fig


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

with gr.Blocks(title="Ising Transformer", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 2D Ising Transformer\n"
        "Autoregressive transformer trained on the 2D Ising model at the critical "
        "temperature T_c ≈ 2.27. "
        "Run **CE training** first, optionally fine-tune with **Variational Inference**, "
        "then **Sample & Eval** to compare both against the held-out test set."
    )

    with gr.Tabs():

        # ── Tab 1: CE training ──────────────────────────────────────────────
        with gr.Tab("Cross-Entropy Training"):
            gr.Markdown(
                "Trains the model to maximise `log q(s)` on the training spin "
                "configurations (teacher forcing, causal attention). "
                "A *Smoke* run does 5 steps to verify everything compiles."
            )
            with gr.Row():
                ce_mode = gr.Radio(["Smoke", "Full"], value="Smoke", label="Mode")
                ce_epoch = gr.Number(value=10, precision=0, minimum=1, label="Epochs")
                ce_bs = gr.Number(value=32, precision=0, minimum=1, label="Batch size")
            with gr.Row():
                ce_lr = gr.Number(value=1e-4, label="Learning rate")
                ce_steps = gr.Number(value=0, precision=0, minimum=0, label="Max steps (0 = no cap)")
            ce_run = gr.Button("Run CE Training", variant="primary")
            ce_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
            ce_ckpt = gr.File(label="Checkpoint")
            ce_run.click(
                run_ce,
                inputs=[ce_mode, ce_epoch, ce_bs, ce_lr, ce_steps],
                outputs=[ce_logs, ce_ckpt],
            )

        # ── Tab 2: VI fine-tuning ───────────────────────────────────────────
        with gr.Tab("Variational Inference Fine-tuning"):
            gr.Markdown(
                "Minimises the variational free energy `F = ⟨E(s)⟩ − T·H[q]` using "
                "the REINFORCE gradient estimator. Warm-starting from the CE "
                "checkpoint is strongly recommended. "
                "A *Smoke* run does 3 steps."
            )
            with gr.Row():
                vi_mode = gr.Radio(["Smoke", "Full"], value="Smoke", label="Mode")
                vi_steps = gr.Number(value=200, precision=0, minimum=1, label="Steps")
                vi_bs = gr.Number(value=16, precision=0, minimum=1, label="Batch size")
            with gr.Row():
                vi_lr = gr.Number(value=1e-4, label="Learning rate")
                vi_warm = gr.Checkbox(value=True, label="Warm-start from CE checkpoint")
            vi_run = gr.Button("Run VI Fine-tuning", variant="primary")
            vi_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
            vi_ckpt = gr.File(label="Checkpoint")
            vi_run.click(
                run_vi,
                inputs=[vi_mode, vi_steps, vi_bs, vi_lr, vi_warm],
                outputs=[vi_logs, vi_ckpt],
            )

        # ── Tab 3: Sample & Eval ────────────────────────────────────────────
        with gr.Tab("Sample & Eval"):
            gr.Markdown(
                "Generates spin configurations from the selected model(s), then runs "
                "the physical-observable evaluation against `spins_test.npy` "
                "(a held-out set, never seen during training).\n\n"
                "Features compared: magnetisation, energy, two-point correlations, "
                "cluster statistics. Distance reported as **Mahalanobis D** in the "
                "decorrelated feature space."
            )
            with gr.Row():
                ev_which = gr.Radio(
                    ["CE", "VI", "Both"], value="Both", label="Model(s) to evaluate"
                )
                ev_n = gr.Number(value=64, precision=0, minimum=4, label="Num samples")
                ev_seed = gr.Number(value=0, precision=0, label="Seed")
            ev_run = gr.Button("Run Sample & Eval", variant="primary")
            ev_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
            with gr.Row():
                ev_ce_img = gr.Plot(label="CE samples")
                ev_vi_img = gr.Plot(label="VI samples")
            ev_run.click(
                run_eval,
                inputs=[ev_which, ev_n, ev_seed],
                outputs=[ev_logs, ev_ce_img, ev_vi_img],
            )


if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1).launch()
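Aside: `_stream` above is a plain Python generator, so it can be smoke-tested outside Gradio. A minimal sketch (the `echo` command is just an illustrative stand-in, not part of the app):

# Hedged sketch: consume _stream outside Gradio ("echo" is illustrative only).
from app import _stream

log = ""
for log in _stream(["echo", "hello"]):
    pass           # each yielded value is the accumulated log so far
print(log)         # the final value ends with the exit-status line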
eval.py
ADDED
@@ -0,0 +1,386 @@
#!/usr/bin/env python
# /// script
# dependencies = [
#     "jax[cuda12]",
#     "equinox",
#     "scipy",
#     "jaxtyping",
# ]
# ///
"""Evaluate a trained Generator against held-out test samples.

For each configuration we compute an 11-dimensional feature vector of physical
observables. The Mahalanobis distance between the real and generated feature
distributions gives a single scalar measure of model quality.

Per-sample feature vector
--------------------------
m, m^2, |m|   magnetisation and its moments
e, e^2        nearest-neighbour energy per spin (periodic BC)
C(1..8)       connected two-point correlation at r = 1, 2, 4, 8
s_mean/N      mean cluster size (4-connected, open BC)
s_max/N       largest cluster size

Ensemble statistics (printed for reference, not part of Mahalanobis)
----------------------------------------------------------------------
chi = N · Var(m) / T           magnetic susceptibility
C_v = N · Var(e) / T²          specific heat
U4  = 1 − <m^4>/(3<m^2>^2)     Binder cumulant
      → 2/3 in the ordered phase
      → 0 in the disordered phase
      ≈ 0.47 at T_c for the 2D Ising model (L→∞)

Distance
--------
D = sqrt( Δμ^T Σ_real^{-1} Δμ )
where Δμ = μ_gen − μ_real and Σ_real is the sample covariance of the
real test features. Per-feature z-scores Δμ_i / σ_real_i are also
reported so you can see which observables deviate most.
"""

import argparse
from pathlib import Path

import numpy as np
import scipy.ndimage
import jax
from tqdm.auto import tqdm

from model import gen_config, snake_order  # snake_order is used in load_test_grids
from sample import load_checkpoint, sample_batch, tokens_to_grids
from train import load_ising_data


# ---------------------------------------------------------------------------
# Physical constants
# ---------------------------------------------------------------------------

J = 1.0
T_C = 2.0 / np.log(1.0 + np.sqrt(2.0))  # exact: 2J / ln(1+√2) ≈ 2.2692

FEATURE_NAMES = [
    "m", "m^2", "|m|",
    "e", "e^2",
    "C(r=1)", "C(r=2)", "C(r=4)", "C(r=8)",
    "s_mean/N", "s_max/N",
]


# ---------------------------------------------------------------------------
# Per-sample observables
# ---------------------------------------------------------------------------

def energy_per_spin(grid: np.ndarray) -> float:
    """Nearest-neighbour energy density with periodic boundary conditions.

    E/N = −J/N · Σ_{⟨ij⟩} s_i s_j
    Each bond is counted once via right- and down-shifts.
    """
    right = np.roll(grid, -1, axis=1)
    down = np.roll(grid, -1, axis=0)
    return float(-J * (grid * right + grid * down).sum() / grid.size)


def connected_correlations(
    grid: np.ndarray,
    distances: tuple[int, ...] = (1, 2, 4, 8),
) -> np.ndarray:
    """Isotropic connected two-point function C(r) = ½[<s_x s_{x+r}> + <s_y s_{y+r}>] − <s>².

    Averaged over both spatial directions and all origin sites using
    periodic boundary conditions.
    """
    m = float(grid.mean())
    corr = []
    for r in distances:
        cx = float((grid * np.roll(grid, r, axis=1)).mean())
        cy = float((grid * np.roll(grid, r, axis=0)).mean())
        corr.append((cx + cy) / 2.0 - m ** 2)
    return np.array(corr, dtype=np.float64)


def cluster_stats(grid: np.ndarray) -> tuple[float, float]:
    """Mean and maximum cluster size for both spin species.

    Uses 4-connectivity (no diagonals) and open boundary conditions.
    Returns sizes normalised by the total number of spins so the result
    is independent of lattice size.

    Note: open BC means edge-spanning clusters are split at the boundary;
    this is applied consistently to both real and generated samples so
    the systematic bias cancels in the Mahalanobis comparison.
    """
    N = grid.size
    all_sizes: list[np.ndarray] = []
    for spin in (1, -1):
        labeled, n_labels = scipy.ndimage.label(grid == spin)
        if n_labels > 0:
            # bincount index 0 is background; skip it
            all_sizes.append(np.bincount(labeled.ravel())[1:])
    if not all_sizes:
        return 0.0, 0.0
    sizes = np.concatenate(all_sizes).astype(np.float64)
    return float(sizes.mean()) / N, float(sizes.max()) / N


def compute_features(grid: np.ndarray) -> np.ndarray:
    """Return the 11-D feature vector for a single ±1 grid of shape (L, L)."""
    m = float(grid.mean())
    e = energy_per_spin(grid)
    cr = connected_correlations(grid)
    s_mean, s_max = cluster_stats(grid)
    return np.array(
        [m, m ** 2, abs(m), e, e ** 2, *cr, s_mean, s_max],
        dtype=np.float64,
    )


def compute_feature_matrix(grids: np.ndarray, desc: str = "features") -> np.ndarray:
    """Compute the (N, 11) feature matrix for a batch of grids."""
    return np.stack(
        [compute_features(grids[i])
         for i in tqdm(range(len(grids)), desc=desc, leave=False)]
    )


# ---------------------------------------------------------------------------
# Ensemble statistics
# ---------------------------------------------------------------------------

def ensemble_stats(X: np.ndarray, T: float = T_C) -> dict[str, float]:
    """Derive thermodynamic ensemble statistics from a feature matrix.

    Arguments
    ---------
    X : (N, 11) feature matrix from ``compute_feature_matrix``.
    T : temperature used for chi and C_v normalisation.
    """
    L = gen_config["lattice_size"]
    N = L * L

    m = X[:, FEATURE_NAMES.index("m")]
    m2 = X[:, FEATURE_NAMES.index("m^2")]
    m4 = m ** 4
    e = X[:, FEATURE_NAMES.index("e")]

    chi = N * float(m.var()) / T
    Cv = N * float(e.var()) / T ** 2
    binder = float(1.0 - m4.mean() / (3.0 * m2.mean() ** 2)) if m2.mean() > 0 else float("nan")

    return {
        "<|m|>": float(np.abs(m).mean()),
        "chi": chi,
        "C_v": Cv,
        "U4": binder,
    }


# ---------------------------------------------------------------------------
# Mahalanobis distance
# ---------------------------------------------------------------------------

def mahalanobis_distance(
    X_ref: np.ndarray,
    X_query: np.ndarray,
    reg: float = 1e-6,
) -> tuple[float, np.ndarray]:
    """Mahalanobis distance of the query mean from the reference distribution.

    D = sqrt( Δμ^T Σ_ref^{-1} Δμ )

    Also returns per-feature z-scores z_i = Δμ_i / σ_ref_i,
    where σ_ref_i = sqrt(Σ_ref[i,i]). |z_i| > 1 indicates a feature
    whose mean differs by more than one real-sample standard deviation.

    Parameters
    ----------
    X_ref   : (N, d) real / reference feature matrix
    X_query : (M, d) generated / query feature matrix
    reg     : diagonal regularisation added to Σ_ref before inversion
    """
    mu_ref = X_ref.mean(axis=0)
    mu_query = X_query.mean(axis=0)
    cov = np.cov(X_ref.T) + reg * np.eye(X_ref.shape[1])
    cov_inv = np.linalg.inv(cov)
    delta = mu_query - mu_ref
    D = float(np.sqrt(max(0.0, delta @ cov_inv @ delta)))
    z_scores = delta / np.sqrt(np.diag(cov))
    return D, z_scores


# ---------------------------------------------------------------------------
# Reporting
# ---------------------------------------------------------------------------

def print_feature_table(X_real: np.ndarray, X_gen: np.ndarray) -> None:
    mu_r = X_real.mean(axis=0)
    sd_r = X_real.std(axis=0)
    mu_g = X_gen.mean(axis=0)
    sd_g = X_gen.std(axis=0)

    col = 13
    hdr = (f" {'Feature':<11} {'Real mean':>{col}} {'Real std':>{col}}"
           f" {'Gen mean':>{col}} {'Gen std':>{col}} {'z-score':>8}")
    print(hdr)
    print(" " + "─" * (len(hdr) - 2))
    for name, mr, sr, mg, sg in zip(FEATURE_NAMES, mu_r, sd_r, mu_g, sd_g):
        z = (mg - mr) / (sr + 1e-12)
        flag = " <" if abs(z) > 1.0 else ""
        print(f" {name:<11} {mr:>{col}.4f} {sr:>{col}.4f}"
              f" {mg:>{col}.4f} {sg:>{col}.4f} {z:>+8.3f}{flag}")
    print()


def print_ensemble_table(stats_real: dict, stats_gen: dict) -> None:
    labels = {
        "<|m|>": "mean |m|",
        "chi": "chi (susceptibility)",
        "C_v": "C_v (specific heat)",
        "U4": "U4 (Binder cumulant)",
    }
    print(f" {'Observable':<26} {'Real':>10} {'Generated':>10}")
    print(" " + "─" * 50)
    for key, label in labels.items():
        r = stats_real[key]
        g = stats_gen[key]
        print(f" {label:<26} {r:>10.4f} {g:>10.4f}")
    print()


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

_SAMPLE_BATCH = 4  # fixed vmapped batch; changing it triggers recompilation


def generate_grids(model, n: int, key: jax.Array, L: int) -> np.ndarray:
    """Sample n grids in batches of _SAMPLE_BATCH with a progress bar.

    Using a fixed batch size means only one JIT compilation happens regardless
    of n. The final partial batch is padded then trimmed.
    """
    batches = []
    n_full, remainder = divmod(n, _SAMPLE_BATCH)
    n_batches = n_full + (1 if remainder else 0)

    with tqdm(total=n, unit="samples", desc="Sampling") as pbar:
        for i in range(n_batches):
            key, subkey = jax.random.split(key)
            tokens = np.asarray(sample_batch(model, _SAMPLE_BATCH, subkey))
            batches.append(tokens)
            pbar.update(min(_SAMPLE_BATCH, n - i * _SAMPLE_BATCH))

    return tokens_to_grids(np.concatenate(batches)[:n], L)


def load_test_grids(
    test_data: Path | None,
    data: Path,
    n: int,
    L: int,
    rng: np.random.Generator,
) -> np.ndarray:
    """Load real test grids, preferring a dedicated test file over the val split.

    Parameters
    ----------
    test_data : optional path to a standalone test .npy file (N, L, L) int8 {-1,+1}
    data      : path to the main spins.npy (used only if test_data is None)
    """
    if test_data is not None:
        spins = np.load(test_data)                   # (N, L, L) int8
        tokens = (spins.astype(np.int32) + 1) // 2   # → {0, 1}
        rows, cols = snake_order(L)
        tokens = tokens[:, rows, cols]               # (N, L²)
    else:
        _, tokens = load_ising_data(data)            # val split of spins.npy

    n = min(n, len(tokens))
    idx = rng.choice(len(tokens), size=n, replace=False)
    return tokens_to_grids(tokens[idx], L)           # (n, L, L), values ±1


def parse_args():
    p = argparse.ArgumentParser(
        description="Compare generated vs real Ising samples via physical observables."
    )
    p.add_argument("--checkpoint", type=Path, required=True,
                   help="Path to the .eqx checkpoint file.")
    p.add_argument("--data", type=Path,
                   default=Path(__file__).parent / "spins.npy",
                   help="Path to spins.npy (default: ./spins.npy). "
                        "Used only if --test-data is not provided.")
    p.add_argument("--test-data", type=Path,
                   default=Path(__file__).parent / "spins_test.npy",
                   help="Dedicated held-out test set (.npy, N×L×L int8 {-1,+1}). "
                        "Takes priority over the val split of --data.")
    p.add_argument("--num-samples", type=int, default=50,
                   help="Number of samples to compare (default: 50).")
    p.add_argument("--samples-file", type=Path, default=None,
                   help="Optional .npy of pre-generated {-1,+1} grids (N,L,L) "
                        "from 'sample.py --output'. Skips generation entirely.")
    p.add_argument("--seed", type=int, default=0)
    return p.parse_args()


def main():
    args = parse_args()
    L = gen_config["lattice_size"]
    rng = np.random.default_rng(args.seed)

    # ── Real samples (test split) ───────────────────────────────────────────
    # Prefer spins_test.npy; fall back to the val split of spins.npy.
    test_path = args.test_data if (args.test_data and args.test_data.exists()) else None
    if test_path:
        print(f"Loading test data from {test_path} …")
    else:
        print("Loading test data from val split of spins.npy …")
    n = args.num_samples
    real_grids = load_test_grids(test_path, args.data, n, L, rng)
    n = len(real_grids)  # may be capped by dataset size

    # ── Generated samples ───────────────────────────────────────────────────
    if args.samples_file is not None:
        print(f"Loading pre-generated samples from {args.samples_file} …")
        gen_grids = np.load(args.samples_file).astype(np.int8)[:n]
        if gen_grids.shape[1:] != (L, L):
            raise ValueError(
                f"samples-file grid shape {gen_grids.shape[1:]} != ({L},{L})"
            )
        n = min(n, len(gen_grids))
        real_grids = real_grids[:n]
    else:
        print(f"Loading checkpoint from {args.checkpoint} …")
        model = load_checkpoint(args.checkpoint)
        key = jax.random.PRNGKey(args.seed)
        gen_grids = generate_grids(model, n, key, L)  # (n, L, L), values ±1

    print(f"\nL = {L} | N = {n} samples per group | T_C = {T_C:.6f}\n")

    # ── Feature matrices ─────────────────────────────────────────────────────
    X_real = compute_feature_matrix(real_grids, desc="Features: real")
    X_gen = compute_feature_matrix(gen_grids, desc="Features: generated")

    # ── Per-feature comparison table ─────────────────────────────────────────
    print("Per-feature statistics (z-score = Δμ / σ_real; '<' marks |z| > 1)\n")
    print_feature_table(X_real, X_gen)

    # ── Ensemble statistics ──────────────────────────────────────────────────
    print("Ensemble statistics\n")
    print_ensemble_table(ensemble_stats(X_real), ensemble_stats(X_gen))

    # ── Mahalanobis distance ─────────────────────────────────────────────────
    D, z = mahalanobis_distance(X_real, X_gen)
    print(f"Mahalanobis distance D = {D:.4f}")
    print(" (D measures how many 'std-devs' the generated feature mean sits")
    print(" from the real distribution in the decorrelated feature space.)")
    print()
    print(" Top deviating features:")
    order = np.argsort(np.abs(z))[::-1]
    for i in order[:5]:
        print(f"   {FEATURE_NAMES[i]:<11} z = {z[i]:+.3f}")


if __name__ == "__main__":
    main()
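Aside: a hedged toy check of `mahalanobis_distance` on synthetic Gaussian features (not Ising data), illustrating that D reports roughly how many reference standard deviations the query mean has shifted:

# Toy check: shift one feature of a standard Gaussian by 2 sigma.
import numpy as np
from eval import mahalanobis_distance

rng = np.random.default_rng(0)
X_ref = rng.normal(size=(5000, 3))                               # reference ≈ N(0, I)
X_query = rng.normal(size=(1000, 3)) + np.array([2.0, 0.0, 0.0])  # mean shift on feature 0

D, z = mahalanobis_distance(X_ref, X_query)
print(round(D, 2))      # ≈ 2.0: the query mean sits ~2 ref-std-devs away
print(np.round(z, 2))   # ≈ [ 2.  0.  0.]: feature 0 carries the deviation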
examples.png
ADDED
ising.py
ADDED
@@ -0,0 +1,20 @@
"""Legacy entry point. This file has been split into:
    model.py  – Generator architecture and gen_config
    train.py  – training loop (python train.py --help)
    sample.py – checkpoint loading and autoregressive sampling (python sample.py --help)
"""
# Re-export model symbols for any code that still does `from ising import ...`
from model import (  # noqa: F401
    snake_order,
    EmbedderBlock,
    FeedForwardBlock,
    AttentionBlock,
    TransformerLayer,
    Encoder,
    Generator,
    gen_config,
)

if __name__ == "__main__":
    import train
    train.main()
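Aside: the re-exports above keep the old import path working; a one-line check (hypothetical caller code):

# Old-style imports still resolve via the shim above.
from ising import Generator, gen_config  # same objects as `from model import ...`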
main.py
ADDED
@@ -0,0 +1,26 @@
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sys, jax
sys.path.insert(0, '.')
from sample import load_checkpoint, sample_batch, tokens_to_grids
model = load_checkpoint('checkpoint.eqx')
lattice_size = model.encoder.embedder_block.lattice_size
key = jax.random.PRNGKey(42)
print('Compiling and sampling 16 configurations...')
tokens = sample_batch(model, 16, key)
tokens = np.asarray(tokens)
grids = tokens_to_grids(tokens, lattice_size)
mags = grids.mean(axis=(1, 2))
print(f'Magnetizations: {np.round(mags, 3)}')
print(f'Mean |m|: {np.abs(mags).mean():.4f}')
fig, axes = plt.subplots(2, 8, figsize=(14, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(grids[i], cmap='gray', vmin=-1, vmax=1, interpolation='nearest')
    ax.set_title(f'm={mags[i]:.2f}', fontsize=7)
    ax.axis('off')
fig.suptitle('Sampled Ising configs (L=32, T=T_c, 2 epochs)', fontsize=10)
plt.tight_layout()
plt.savefig('samples.png', dpi=150, bbox_inches='tight')
print('Saved samples.png')
metadata.json
ADDED
@@ -0,0 +1,17 @@
{
    "lattice_size": 32,
    "sample_count": 10000,
    "temperature": 2.269185314213022,
    "temperature_note": "2D Ising critical temperature T_c = 2J/ln(1+sqrt(2))",
    "coupling": 1.0,
    "spin_values": [
        -1,
        1
    ],
    "burn_in_sweeps": 200,
    "sample_interval_sweeps": 5,
    "n_chains": 10,
    "algorithm": "Wolff single-cluster",
    "base_seed": 1778172101,
    "generation_time_s": 4.9
}
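Aside: a hedged sketch of cross-checking the uploaded dataset against this metadata (it assumes spins.npy holds the full set the metadata describes, stored as int8 ±1 grids as eval.py expects):

# Sanity-check spins.npy against metadata.json (shape and dtype are assumptions).
import json
import numpy as np

meta = json.load(open("metadata.json"))
spins = np.load("spins.npy")

L, n = meta["lattice_size"], meta["sample_count"]
assert spins.shape == (n, L, L)             # (10000, 32, 32)
assert set(np.unique(spins)) <= {-1, 1}     # spin values ±1
print(f"{n} Wolff-cluster samples at T = {meta['temperature']:.4f}")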
model.py
ADDED
@@ -0,0 +1,362 @@
"""Transformer model for autoregressive Ising spin generation.

Architecture: causal (GPT-style) transformer with per-site positional
embeddings in snake (boustrophedon) order. The model is trained to maximise
p(s_0, s_1, ..., s_{N-1}) = ∏_t p(s_t | s_0, ..., s_{t-1}), where the spin
sites are visited in snake order over the L×L lattice.
"""

from collections.abc import Mapping

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Float, Int


def snake_order(size: int) -> tuple[np.ndarray, np.ndarray]:
    """Return (rows, cols) index arrays traversing an L×L grid in snake order.

    Even rows go left-to-right; odd rows go right-to-left. The returned
    arrays have length size² and implement numpy advanced indexing:
        grid[rows, cols]        → 1-D sequence in snake order
        grid[rows, cols] = seq  → scatter a sequence back to the grid
    """
    if size <= 0:
        raise ValueError("size must be positive")
    rows, cols = [], []
    for row in range(size):
        columns = range(size) if row % 2 == 0 else range(size - 1, -1, -1)
        for col in columns:
            rows.append(row)
            cols.append(col)
    return np.array(rows), np.array(cols)


# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------

class EmbedderBlock(eqx.Module):
    """Spin-state + lattice-position embedder.

    Each position in the snake-order sequence gets three embeddings summed:
      • a learned spin-state embedding (token ∈ {0, 1})
      • a learned row-position embedding
      • a learned column-position embedding

    The row/column indices are derived from `snake_order` at trace time, so
    they fold to compile-time constants; no array model parameters are
    stored for them.
    """

    state_embedder: eqx.nn.Embedding
    row_embedder: eqx.nn.Embedding
    column_embedder: eqx.nn.Embedding
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout
    lattice_size: int = eqx.field(static=True)

    def __init__(
        self,
        state_size: int,
        lattice_size: int,
        embedding_size: int,
        hidden_size: int,
        dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        state_key, row_key, col_key = jax.random.split(key, 3)
        self.state_embedder = eqx.nn.Embedding(
            num_embeddings=state_size, embedding_size=embedding_size, key=state_key
        )
        self.row_embedder = eqx.nn.Embedding(
            num_embeddings=lattice_size, embedding_size=embedding_size, key=row_key
        )
        self.column_embedder = eqx.nn.Embedding(
            num_embeddings=lattice_size, embedding_size=embedding_size, key=col_key
        )
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)
        self.lattice_size = lattice_size

    def __call__(
        self,
        states: Int[Array, " seq_len"],
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        rows, cols = snake_order(self.lattice_size)  # concrete at trace time
        x_states = jax.vmap(self.state_embedder)(states)
        x_rows = jax.vmap(self.row_embedder)(jnp.asarray(rows))
        x_cols = jax.vmap(self.column_embedder)(jnp.asarray(cols))
        x = x_states + x_rows + x_cols
        x = jax.vmap(self.layernorm)(x)
        x = self.dropout(x, inference=not enable_dropout, key=key)
        return x


class FeedForwardBlock(eqx.Module):
    """Position-wise feed-forward block with residual connection."""

    mlp: eqx.nn.Linear
    output: eqx.nn.Linear
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        mlp_key, out_key = jax.random.split(key)
        self.mlp = eqx.nn.Linear(hidden_size, intermediate_size, key=mlp_key)
        self.output = eqx.nn.Linear(intermediate_size, hidden_size, key=out_key)
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: Float[Array, " hidden_size"],
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, " hidden_size"]:
        x = jax.nn.gelu(self.mlp(inputs))
        x = self.output(x)
        x = self.dropout(x, inference=not enable_dropout, key=key)
        x = x + inputs
        x = self.layernorm(x)
        return x


class AttentionBlock(eqx.Module):
    """Multi-head self-attention with causal (lower-triangular) mask."""

    attention: eqx.nn.MultiheadAttention
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout
    num_heads: int = eqx.field(static=True)

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        self.num_heads = num_heads
        self.attention = eqx.nn.MultiheadAttention(
            num_heads=num_heads,
            query_size=hidden_size,
            use_query_bias=True,
            use_key_bias=True,
            use_value_bias=True,
            use_output_bias=True,
            dropout_p=attention_dropout_rate,
            key=key,
        )
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        mask: Int[Array, " seq_len"] | None,
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        attn_key, drop_key = (None, None) if key is None else jax.random.split(key)
        if mask is not None:
            mask = self._causal_mask(mask)
        x = self.attention(
            query=inputs, key_=inputs, value=inputs,
            mask=mask, inference=not enable_dropout, key=attn_key,
        )
        x = self.dropout(x, inference=not enable_dropout, key=drop_key)
        x = x + inputs
        x = jax.vmap(self.layernorm)(x)
        return x

    def _causal_mask(
        self, mask: Int[Array, " seq_len"]
    ) -> Float[Array, "num_heads seq_len seq_len"]:
        """Lower-triangular mask combined with a padding mask."""
        n = mask.shape[0]
        pad = jnp.multiply(mask[:, None], mask[None, :])        # [n, n]
        causal = jnp.tril(jnp.ones((n, n), dtype=mask.dtype))   # [n, n]
        m = jnp.multiply(pad, causal)                           # [n, n]
        m = jnp.broadcast_to(m[None], (self.num_heads, n, n))   # [H, n, n]
        return m.astype(jnp.float32)


class TransformerLayer(eqx.Module):
    """One transformer block: attention followed by feed-forward."""

    attention_block: AttentionBlock
    ff_block: FeedForwardBlock

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        attn_key, ff_key = jax.random.split(key)
        self.attention_block = AttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            key=attn_key,
        )
        self.ff_block = FeedForwardBlock(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            dropout_rate=dropout_rate,
            key=ff_key,
        )

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        mask: Int[Array, " seq_len"] | None = None,
        *,
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        attn_key, ff_key = (None, None) if key is None else jax.random.split(key)
        x = self.attention_block(inputs, mask, enable_dropout=enable_dropout, key=attn_key)
        n = x.shape[0]
        ff_keys = None if ff_key is None else jax.random.split(ff_key, n)
        x = jax.vmap(self.ff_block, in_axes=(0, None, 0))(x, enable_dropout, ff_keys)
        return x


# ---------------------------------------------------------------------------
# Encoder and top-level Generator
# ---------------------------------------------------------------------------

class Encoder(eqx.Module):
    """Stack of transformer layers over a snake-ordered spin sequence."""

    embedder_block: EmbedderBlock
    layers: list[TransformerLayer]

    def __init__(
        self,
        state_size: int,
        lattice_size: int,
        embedding_size: int,
        hidden_size: int,
        intermediate_size: int,
        num_layers: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        emb_key, layer_key = jax.random.split(key)
        self.embedder_block = EmbedderBlock(
            state_size=state_size,
            lattice_size=lattice_size,
            embedding_size=embedding_size,
            hidden_size=hidden_size,
            dropout_rate=dropout_rate,
            key=emb_key,
        )
        layer_keys = jax.random.split(layer_key, num_layers)
        self.layers = [
            TransformerLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                key=lk,
            )
            for lk in layer_keys
        ]

    def __call__(
        self,
        states: Int[Array, " seq_len"],
        *,
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        emb_key, l_key = (None, None) if key is None else jax.random.split(key)
        x = self.embedder_block(states, enable_dropout=enable_dropout, key=emb_key)
        mask = jnp.ones_like(states, dtype=jnp.int32)  # no padding; causal only
        for layer in self.layers:
            cl_key, l_key = (None, None) if l_key is None else jax.random.split(l_key)
            x = layer(x, mask, enable_dropout=enable_dropout, key=cl_key)
        return x


class Generator(eqx.Module):
    """Autoregressive transformer generator for Ising spin configurations.

    Input:  token_ids – integer spin tokens {0=down, 1=up} in snake order.
    Output: logits    – shape (seq_len, state_size), where logits[t] is the
                        predicted distribution over the spin at position t+1
                        given positions 0..t.
    """

    encoder: Encoder
    lm_head: eqx.nn.Linear
    dropout: eqx.nn.Dropout

    def __init__(self, config: Mapping, key: jax.random.PRNGKey):
        enc_key, head_key = jax.random.split(key)
        self.encoder = Encoder(
            state_size=config["state_size"],
            lattice_size=config["lattice_size"],
            embedding_size=config["hidden_size"],
            hidden_size=config["hidden_size"],
            intermediate_size=config["intermediate_size"],
            num_layers=config["num_hidden_layers"],
            num_heads=config["num_attention_heads"],
            dropout_rate=config["hidden_dropout_prob"],
            attention_dropout_rate=config["attention_probs_dropout_prob"],
            key=enc_key,
        )
        self.lm_head = eqx.nn.Linear(
            in_features=config["hidden_size"],
            out_features=config["state_size"],
            key=head_key,
        )
        self.dropout = eqx.nn.Dropout(config["hidden_dropout_prob"])

    def __call__(
        self,
        inputs: dict[str, Int[Array, " seq_len"]],
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len state_size"]:
        e_key, d_key = (None, None) if key is None else jax.random.split(key)
        x = self.encoder(inputs["token_ids"], enable_dropout=enable_dropout, key=e_key)
        x = self.dropout(x, inference=not enable_dropout, key=d_key)
        return jax.vmap(self.lm_head)(x)


# ---------------------------------------------------------------------------
# Default configuration
# ---------------------------------------------------------------------------

gen_config = {
    "state_size": 2,        # spin tokens: 0 (↓) or 1 (↑)
    "lattice_size": 32,     # L×L lattice → L² = 1024 sequence length
    "hidden_size": 128,
    "num_hidden_layers": 2,
    "num_attention_heads": 2,
    "hidden_act": "gelu",
    "intermediate_size": 512,
    "hidden_dropout_prob": 0.1,
    "attention_probs_dropout_prob": 0.1,
}
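Aside: a minimal shape check using only the symbols defined above (the all-zeros token buffer is an arbitrary placeholder, not a meaningful spin state):

# Forward-pass shape check for the Generator.
import jax
import jax.numpy as jnp
from model import Generator, gen_config, snake_order

model = Generator(config=gen_config, key=jax.random.PRNGKey(0))

L = gen_config["lattice_size"]               # 32
tokens = jnp.zeros(L * L, dtype=jnp.int32)   # placeholder spins in snake order
logits = model({"token_ids": tokens})        # shape (1024, 2)
print(logits.shape)

rows, cols = snake_order(4)                  # tiny lattice for illustration
print(cols[:8])                              # [0 1 2 3 3 2 1 0]: row 1 reversed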
requirements.txt
ADDED
@@ -0,0 +1,9 @@
jax[cuda12]
equinox
optax
einops
tqdm
jaxtyping
gradio>=4.0
matplotlib
scipy
sample.py
ADDED
@@ -0,0 +1,208 @@
#!/usr/bin/env python
# /// script
# dependencies = [
#     "jax[cuda12]",
#     "equinox",
#     "matplotlib",
#     "jaxtyping",
# ]
# ///
"""Sample spin configurations from a trained Ising Generator checkpoint.

Usage:
    python sample.py --checkpoint model.eqx [--num-samples 16]
                     [--output samples.npy] [--plot] [--seed 0]

How autoregressive sampling works
----------------------------------
The model is trained with a causal (lower-triangular) attention mask, so at
position t the output logits[t] are a function of spins s_0 … s_t only.
We exploit this to sample the full sequence one spin at a time:

1. Sample s_0 uniformly (the model has no BOS token).
2. For t = 0, 1, …, L²-2:
   a. Run the full forward pass on the current token buffer.
      Spins at positions > t are still placeholder zeros, but causal
      masking prevents the network from attending to them.
   b. Draw s_{t+1} ~ Categorical(softmax(logits[t])).
   c. Write s_{t+1} into the buffer.

This is O(L⁴) in compute (L² steps × L² attention), which is ~1 B ops for a
32×32 lattice. `jax.lax.scan` compiles the loop body once so subsequent
calls are fast.
"""

import argparse
import functools
from pathlib import Path

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Int

from model import Generator, gen_config, snake_order


# ---------------------------------------------------------------------------
# Checkpoint I/O
# ---------------------------------------------------------------------------

def load_checkpoint(
    path: Path,
    config: dict = gen_config,
    key: jax.Array | None = None,
) -> Generator:
    """Deserialise a Generator from *path*.

    A fresh model is initialised with *config* (weights are immediately
    overwritten), so *key* only needs to be reproducible across calls if you
    care about the random seed used for the skeleton; in practice any key
    works.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    skeleton = Generator(config=config, key=key)
    return eqx.tree_deserialise_leaves(path, skeleton)


# ---------------------------------------------------------------------------
# Sampling
# ---------------------------------------------------------------------------

def sample_sequence(
    model: Generator,
    key: jax.random.PRNGKey,
) -> Int[Array, " seq_len"]:
    """Autoregressively sample one spin sequence in snake order.

    This function is JAX-traceable and safe to use inside ``jax.vmap`` or
    ``jax.lax.scan``. JIT-compile it (or wrap in ``sample_batch``) for best
    performance; the first call will take longer due to compilation.
    """
    lattice_size = model.encoder.embedder_block.lattice_size
    seq_len = lattice_size * lattice_size

    def step(carry, t):
        tokens, step_key = carry
        step_key, sample_key = jax.random.split(step_key)
        logits = model({"token_ids": tokens}, enable_dropout=False, key=None)
        # logits[t] → distribution over s_{t+1}
        next_token = jax.random.categorical(sample_key, logits[t])
        tokens = tokens.at[t + 1].set(next_token)
        return (tokens, step_key), None

    key, first_key = jax.random.split(key)
    first_token = jax.random.randint(first_key, shape=(), minval=0, maxval=2)
    tokens = jnp.zeros(seq_len, dtype=jnp.int32).at[0].set(first_token)
    (tokens, _), _ = jax.lax.scan(step, (tokens, key), jnp.arange(seq_len - 1))
    return tokens


@eqx.filter_jit
def sample_batch(
    model: Generator,
    num_samples: int,
    key: jax.random.PRNGKey,
) -> Int[Array, "num_samples seq_len"]:
|
| 109 |
+
"""Sample *num_samples* configurations in parallel (vmapped + JIT'd).
|
| 110 |
+
|
| 111 |
+
The first call triggers compilation; subsequent calls with the same
|
| 112 |
+
``num_samples`` reuse the compiled code.
|
| 113 |
+
"""
|
| 114 |
+
keys = jax.random.split(key, num_samples)
|
| 115 |
+
return jax.vmap(sample_sequence, in_axes=(None, 0))(model, keys)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ---------------------------------------------------------------------------
|
| 119 |
+
# Grid conversion
|
| 120 |
+
# ---------------------------------------------------------------------------
|
| 121 |
+
|
| 122 |
+
def tokens_to_grid(
|
| 123 |
+
tokens: np.ndarray | Array,
|
| 124 |
+
lattice_size: int,
|
| 125 |
+
) -> np.ndarray:
|
| 126 |
+
"""Convert a snake-ordered token sequence {0, 1} β a {-1, +1} LΓL grid."""
|
| 127 |
+
tokens = np.asarray(tokens)
|
| 128 |
+
rows, cols = snake_order(lattice_size)
|
| 129 |
+
grid = np.empty((lattice_size, lattice_size), dtype=np.int8)
|
| 130 |
+
grid[rows, cols] = (tokens * 2 - 1).astype(np.int8)
|
| 131 |
+
return grid
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def tokens_to_grids(
|
| 135 |
+
tokens: np.ndarray | Array,
|
| 136 |
+
lattice_size: int,
|
| 137 |
+
) -> np.ndarray:
|
| 138 |
+
"""Batch version of ``tokens_to_grid``. Input shape: (N, LΒ²)."""
|
| 139 |
+
tokens = np.asarray(tokens)
|
| 140 |
+
return np.stack([tokens_to_grid(tokens[i], lattice_size) for i in range(len(tokens))])
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ---------------------------------------------------------------------------
|
| 144 |
+
# CLI
|
| 145 |
+
# ---------------------------------------------------------------------------
|
| 146 |
+
|
| 147 |
+
def parse_args():
|
| 148 |
+
p = argparse.ArgumentParser(description="Sample from a trained Ising Generator.")
|
| 149 |
+
p.add_argument("--checkpoint", type=Path, required=True,
|
| 150 |
+
help="Path to the .eqx checkpoint file.")
|
| 151 |
+
p.add_argument("--num-samples", type=int, default=16)
|
| 152 |
+
p.add_argument("--output", type=Path, default=None,
|
| 153 |
+
help="Save sampled {-1,+1} grids as a .npy file (N, L, L).")
|
| 154 |
+
p.add_argument("--plot", action="store_true",
|
| 155 |
+
help="Display a grid of sampled configurations with matplotlib.")
|
| 156 |
+
p.add_argument("--seed", type=int, default=0)
|
| 157 |
+
return p.parse_args()
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def main():
|
| 161 |
+
args = parse_args()
|
| 162 |
+
|
| 163 |
+
print(f"Loading checkpoint from {args.checkpoint} β¦")
|
| 164 |
+
model = load_checkpoint(args.checkpoint)
|
| 165 |
+
lattice_size = model.encoder.embedder_block.lattice_size
|
| 166 |
+
print(f" lattice_size={lattice_size}, seq_len={lattice_size**2}")
|
| 167 |
+
|
| 168 |
+
key = jax.random.PRNGKey(args.seed)
|
| 169 |
+
print(f"Sampling {args.num_samples} configurations "
|
| 170 |
+
f"(compiling on first call) β¦")
|
| 171 |
+
tokens = sample_batch(model, args.num_samples, key)
|
| 172 |
+
tokens = np.asarray(tokens)
|
| 173 |
+
|
| 174 |
+
grids = tokens_to_grids(tokens, lattice_size) # (N, L, L), values {-1, +1}
|
| 175 |
+
print(f" shape: {grids.shape} dtype: {grids.dtype}")
|
| 176 |
+
print(f" mean magnetization : {grids.mean():.4f}")
|
| 177 |
+
print(f" mean |magnetization|: {np.abs(grids.mean(axis=(1, 2))).mean():.4f}")
|
| 178 |
+
|
| 179 |
+
if args.output is not None:
|
| 180 |
+
np.save(args.output, grids)
|
| 181 |
+
print(f"Saved β {args.output}")
|
| 182 |
+
|
| 183 |
+
if args.plot:
|
| 184 |
+
try:
|
| 185 |
+
import matplotlib.pyplot as plt
|
| 186 |
+
except ImportError:
|
| 187 |
+
print("matplotlib not available; skipping plot (pip install matplotlib).")
|
| 188 |
+
return
|
| 189 |
+
|
| 190 |
+
cols = min(8, args.num_samples)
|
| 191 |
+
rows = (args.num_samples + cols - 1) // cols
|
| 192 |
+
fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5))
|
| 193 |
+
axes = np.array(axes).reshape(-1)
|
| 194 |
+
for i, ax in enumerate(axes):
|
| 195 |
+
if i < len(grids):
|
| 196 |
+
ax.imshow(grids[i], cmap="gray", vmin=-1, vmax=1, interpolation="nearest")
|
| 197 |
+
ax.axis("off")
|
| 198 |
+
fig.suptitle(
|
| 199 |
+
f"Sampled Ising configurations "
|
| 200 |
+
f"(L={lattice_size}, n={args.num_samples})",
|
| 201 |
+
fontsize=10,
|
| 202 |
+
)
|
| 203 |
+
plt.tight_layout()
|
| 204 |
+
plt.show()
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
if __name__ == "__main__":
|
| 208 |
+
main()
|
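For programmatic use from other code in this Space (vi_train.py imports it the same way), a short illustrative snippet; the checkpoint path is an example, not a guarantee that the file exists:

    from pathlib import Path

    import jax
    import numpy as np

    from sample import load_checkpoint, sample_batch, tokens_to_grids

    # Path is illustrative: any .eqx checkpoint written by train.py works here.
    model = load_checkpoint(Path("outputs/ce_checkpoint.eqx"))
    tokens = sample_batch(model, 4, jax.random.PRNGKey(0))   # (4, 1024) ints in {0, 1}
    grids = tokens_to_grids(np.asarray(tokens), 32)          # (4, 32, 32) in {-1, +1}
    print(np.abs(grids.mean(axis=(1, 2))).mean())            # mean |magnetization|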
samples-2-epoch.png
ADDED
spins.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c1890a0f2baeee4212b3bd79b74bff1f1f31135d9031ffb8c79bc558eae23a83
size 10240128
spins_test.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:040dea2a4f3b5753c7531d7a55643a7be3c8be4a7d9b9c1f11f481ea522ff67e
size 1024128
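(For orientation: after the standard 128-byte .npy header, these LFS sizes are consistent with int8 arrays of shape (10000, 32, 32) for spins.npy (128 + 10000 * 1024 = 10240128 bytes) and (1000, 32, 32) for spins_test.npy (128 + 1000 * 1024 = 1024128 bytes), i.e. the (N, L, L) int8 layout that load_ising_data in train.py expects.)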
train.py
ADDED
@@ -0,0 +1,229 @@
#!/usr/bin/env python
# /// script
# dependencies = [
#   "jax[cuda12]",
#   "equinox",
#   "optax",
#   "einops",
#   "tqdm",
#   "jaxtyping",
# ]
# ///
"""Training script for the Ising spin Generator.

Usage:
    python train.py [--epochs N] [--batch-size B] [--learning-rate LR]
                    [--data path/to/spins.npy] [--output-checkpoint model.eqx]
"""

import argparse
import functools
from pathlib import Path

import einops
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
from tqdm.auto import tqdm

from model import Generator, gen_config, snake_order


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

def load_ising_data(
    path: Path, train_frac: float = 0.9
) -> tuple[np.ndarray, np.ndarray]:
    """Load spins.npy, map {-1,1} → {0,1}, flatten with snake ordering.

    Returns ``(train_tokens, val_tokens)``, each ``(N, L²)`` int32.
    """
    spins = np.load(path)  # (N, L, L) int8
    lattice_size = spins.shape[1]
    tokens = (spins.astype(np.int32) + 1) // 2  # (N, L, L), values in {0,1}
    rows, cols = snake_order(lattice_size)
    tokens = tokens[:, rows, cols]  # (N, L²)
    n_train = int(len(tokens) * train_frac)
    return tokens[:n_train], tokens[n_train:]


# ---------------------------------------------------------------------------
# Batch preparation
# ---------------------------------------------------------------------------

def prepare_batch(batch: np.ndarray, num_devices: int) -> dict:
    """Reshape ``(batch, seq)`` → ``(devices, batch//devices, seq)`` for pmap."""
    token_ids = einops.rearrange(
        batch,
        "(devices batch) seq -> devices batch seq",
        devices=num_devices,
    )
    return {"token_ids": token_ids}


# ---------------------------------------------------------------------------
# Training / eval steps
# ---------------------------------------------------------------------------

@eqx.filter_value_and_grad
def compute_loss(model, inputs, key):
    """Autoregressive cross-entropy: logits[:, :-1] predicts token_ids[:, 1:]."""
    batch_size = inputs["token_ids"].shape[0]
    keys = jax.random.split(key, batch_size)
    logits = jax.vmap(model, in_axes=(0, None, 0))(inputs, True, keys)
    return jnp.mean(
        optax.softmax_cross_entropy_with_integer_labels(
            logits=logits[:, :-1, :],
            labels=inputs["token_ids"][:, 1:],
        )
    )


def make_step(model, inputs, opt_state, key, tx):
    """One pmapped optimisation step: loss, pmean'd grads, optimiser update."""
    key, new_key = jax.random.split(key)
    loss, grads = compute_loss(model, inputs, key)
    grads = jax.lax.pmean(grads, axis_name="devices")
    updates, opt_state = tx.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state, new_key


def make_eval_step(model, inputs):
    """Per-device mean NLL (nats/token), called inside pmap."""
    logits = jax.vmap(functools.partial(model, enable_dropout=False))(inputs)
    return jnp.mean(
        optax.softmax_cross_entropy_with_integer_labels(
            logits=logits[:, :-1, :],
            labels=inputs["token_ids"][:, 1:],
        )
    )


p_make_eval_step = eqx.filter_pmap(make_eval_step)


# ---------------------------------------------------------------------------
# pmap helpers
# ---------------------------------------------------------------------------

def replicate_for_pmap(value, devices):
    """Broadcast every leaf of a pytree across *devices* (one copy each)."""
    mesh = jax.sharding.Mesh(np.asarray(devices), ("devices",))
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("devices"))

    def replicate_leaf(leaf):
        leaf = jnp.asarray(leaf)
        leaf = jnp.broadcast_to(leaf, (len(devices),) + leaf.shape)
        return jax.device_put(leaf, sharding)

    return jax.tree.map(replicate_leaf, value)


def unreplicate_from_pmap(value):
    """Take the first device's copy of a pmap-replicated pytree."""
    return jax.tree.map(lambda leaf: leaf[0], value)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args():
    p = argparse.ArgumentParser(description="Train an Ising spin generator.")
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--batch-size", type=int, default=32)
    p.add_argument("--learning-rate", type=float, default=1e-4)
    p.add_argument("--max-train-steps", type=int, default=None)
    p.add_argument("--max-eval-batches", type=int, default=None)
    p.add_argument("--data", type=Path,
                   default=Path(__file__).parent / "spins.npy")
    p.add_argument("--output-checkpoint", type=Path, default=None)
    p.add_argument("--seed", type=int, default=5678)
    return p.parse_args()


def main():
    args = parse_args()
    num_devices = jax.device_count()
    print(f"JAX devices: {jax.devices()}")
    assert args.batch_size % num_devices == 0, (
        "batch-size must be a multiple of the number of devices"
    )

    key = jax.random.PRNGKey(args.seed)
    model_key, train_key = jax.random.split(key)
    model = Generator(config=gen_config, key=model_key)

    train_tokens, val_tokens = load_ising_data(args.data)
    print(
        f"Train: {len(train_tokens):,}  Val: {len(val_tokens):,}  "
        f"Seq len: {train_tokens.shape[1]}"
    )

    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=args.learning_rate),
    )
    # Mask to float-only leaves so integer bookkeeping fields are excluded.
    tx = optax.masked(tx, jax.tree.map(eqx.is_inexact_array, model))
    opt_state = tx.init(model)

    p_make_step = eqx.filter_pmap(
        functools.partial(make_step, tx=tx), axis_name="devices"
    )

    devices = jax.local_devices()
    opt_state = replicate_for_pmap(opt_state, devices)
    model = replicate_for_pmap(model, devices)
    train_key = replicate_for_pmap(train_key, devices)

    global_step = 0
    for epoch in range(args.epochs):
        rng = np.random.default_rng(args.seed + epoch)
        shuffled = train_tokens[rng.permutation(len(train_tokens))]

        num_batches = len(shuffled) // args.batch_size
        if args.max_train_steps is not None:
            num_batches = min(num_batches, max(args.max_train_steps - global_step, 0))

        with tqdm(range(num_batches), unit="steps",
                  desc=f"Epoch {epoch + 1}/{args.epochs}") as pbar:
            for step in pbar:
                batch = shuffled[step * args.batch_size : (step + 1) * args.batch_size]
                inputs = prepare_batch(batch, num_devices)
                loss, model, opt_state, train_key = p_make_step(
                    model, inputs, opt_state, train_key
                )
                global_step += 1
                # Mean (not sum) over per-device losses, so the displayed
                # value does not scale with the device count.
                pbar.set_postfix(loss=float(np.mean(loss)))
                if args.max_train_steps and global_step >= args.max_train_steps:
                    break

        if args.max_train_steps and global_step >= args.max_train_steps:
            break

    # ---- validation ----
    num_val = len(val_tokens) // args.batch_size
    if args.max_eval_batches is not None:
        num_val = min(num_val, args.max_eval_batches)

    val_losses = []
    for step in tqdm(range(num_val), unit="steps", desc="Validation"):
        batch = val_tokens[step * args.batch_size : (step + 1) * args.batch_size]
        inputs = prepare_batch(batch, num_devices)
        val_losses.append(float(np.mean(p_make_eval_step(model, inputs))))

    print(f"Val NLL: {np.mean(val_losses):.4f} nats/token")

    if args.output_checkpoint is not None:
        args.output_checkpoint.parent.mkdir(parents=True, exist_ok=True)
        eqx.tree_serialise_leaves(
            args.output_checkpoint, unreplicate_from_pmap(model)
        )
        print(f"Saved checkpoint → {args.output_checkpoint}")


if __name__ == "__main__":
    main()
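Both compute_loss and make_eval_step rely on the same one-step shift: logits at position t are scored against the token at position t + 1, and the last position (which has no successor) is dropped. A toy shape check of that slicing, using nothing beyond the shapes above:

    import jax.numpy as jnp

    # Toy illustration of the autoregressive shift used in compute_loss /
    # make_eval_step: logits at position t predict the token at position t + 1.
    token_ids = jnp.array([[0, 1, 1, 0]])   # (batch=1, seq_len=4)
    logits = jnp.zeros((1, 4, 2))           # (batch, seq_len, state_size)
    pred = logits[:, :-1, :]                # positions 0..2  → (1, 3, 2)
    labels = token_ids[:, 1:]               # tokens    1..3  → (1, 3)
    assert pred.shape[:2] == labels.shape   # one prediction per next token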
vi_train.py
ADDED
@@ -0,0 +1,288 @@
#!/usr/bin/env python
# /// script
# dependencies = [
#   "jax[cuda12]",
#   "equinox",
#   "optax",
#   "tqdm",
#   "jaxtyping",
# ]
# ///
"""Variational inference fine-tuning of the Ising Generator.

Objective
---------
Minimise the variational free energy

    F[q] = E_{s~q}[E(s)] − T · H[q]
         = T · E_{s~q}[ E(s)/T + log q(s) ]

which equals KL(q ‖ p*) up to the constant log Z, where
p*(s) ∝ exp(−E(s)/T) is the Ising Boltzmann distribution.
As F decreases, the model q approaches the correct physics.

Gradient estimator (REINFORCE / score-function)
-----------------------------------------------
    ∇_θ F = T · E_{s~q}[( E(s)/T + log q(s) − b ) · ∇_θ log q(s)]

where b = batch-mean reward is a variance-reducing control variate (a
baseline); subtracting it leaves the gradient unbiased because
E_{s~q}[∇_θ log q(s)] = 0.

Per training step
-----------------
1. Sample a batch of configs from the current model q_θ (slow on CPU)
2. Compute Ising energy E(s) for each sample (fast, no model)
3. Compute log q(s) via a teacher-forced forward pass (fast, one pass)
4. Assemble reward R = E/T + log q, subtract baseline, backprop (fast)

Speed note
----------
Step 1 dominates on CPU (~5 s / sample on a 32×32 lattice). On GPU it is
typically 10–100× faster and VI training with batch_size ≥ 32 is practical.
The --checkpoint flag warm-starts from a CE-pretrained model, which
dramatically reduces the number of VI steps needed to converge.

Monitoring
----------
At each step the following quantities are logged:

    e   = ⟨E/N⟩         mean energy per spin (converges to data ≈ −1.45)
    h   = −⟨log q⟩/N    entropy per spin in nats (random init ≈ 0.693)
    f   = e − T·h       Helmholtz free energy per spin (we minimise this)
    |m| = ⟨|Σs / N|⟩    mean absolute magnetisation
"""

import argparse
from pathlib import Path

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
from tqdm.auto import tqdm

from model import Generator, gen_config, snake_order
from sample import load_checkpoint, sample_batch


# ---------------------------------------------------------------------------
# Physical constants
# ---------------------------------------------------------------------------

J = 1.0
T_C = 2.0 / np.log(1.0 + np.sqrt(2.0))  # exact 2D Ising T_c ≈ 2.2692


# ---------------------------------------------------------------------------
# Ising energy (pure JAX; no model parameters, no gradient needed)
# ---------------------------------------------------------------------------

def ising_energy_per_spin(token_ids: jax.Array) -> jax.Array:
    """Ising energy per spin from a snake-ordered token sequence {0, 1}.

    Hamiltonian: H = −J Σ_{⟨ij⟩} s_i s_j with periodic boundary conditions.
    Returns a scalar.
    """
    L = gen_config["lattice_size"]
    rows, cols = snake_order(L)  # concrete at trace time
    spins = (token_ids * 2 - 1).astype(jnp.float32)  # {0,1} → {-1,+1}
    grid = jnp.zeros((L, L)).at[jnp.asarray(rows), jnp.asarray(cols)].set(spins)
    right = jnp.roll(grid, -1, axis=1)
    down = jnp.roll(grid, -1, axis=0)
    return -J * (grid * right + grid * down).sum() / (L * L)


# ---------------------------------------------------------------------------
# VI loss (REINFORCE with per-batch baseline)
# ---------------------------------------------------------------------------

@eqx.filter_value_and_grad(has_aux=True)
def compute_vi_loss(
    model,
    token_ids: jax.Array,
    T: float,
) -> tuple[jax.Array, dict]:
    """REINFORCE proxy loss for ∇_θ F[q].

    The returned scalar is *not* the free energy; it is the REINFORCE
    surrogate whose gradient equals ∇_θ F/T. Use aux["f"] to track F.

    Parameters
    ----------
    model     : Generator
    token_ids : int array (batch, seq_len)  samples drawn from q_θ
    T         : target temperature

    Returns
    -------
    (loss, aux), grads
        aux keys: e, h, f (per spin), |m|, reward_std (variance diagnostic)
    """
    N = gen_config["lattice_size"] ** 2

    # ── log q(s) via teacher-forced forward pass ─────────────────────────────
    # Disable dropout so we get the exact model log-probability.
    # in_axes=(0, None, None): vmap over batch; broadcast enable_dropout and key.
    logits = jax.vmap(model, in_axes=(0, None, None))(
        {"token_ids": token_ids}, False, None
    )  # (batch, seq_len, state_size)

    # Σ_t log p(s_t | s_{<t}), summed over the sequence axis
    log_q = -optax.softmax_cross_entropy_with_integer_labels(
        logits[:, :-1, :], token_ids[:, 1:]
    ).sum(axis=-1)  # (batch,)

    # ── Ising energies ────────────────────────────────────────────────────────
    # jax.lax.stop_gradient keeps energies out of the autodiff graph.
    energies = jax.lax.stop_gradient(
        jax.vmap(ising_energy_per_spin)(token_ids)  # (batch,) per-spin energies
    )

    # ── REINFORCE reward R = E/T + log q(s) ──────────────────────────────────
    # E must be the *total* energy (N × per-spin energy) so that it is weighed
    # against the total log-probability log q on the same scale.
    reward = jax.lax.stop_gradient(N * energies / T + log_q)
    baseline = reward.mean()

    # Proxy loss: ∇ loss = E_q[(R − b) · ∇ log q] = ∇F / T.
    # Note the positive sign on log_q: R is a *cost*, so gradient descent on
    # this surrogate decreases F.
    loss = jnp.mean(jax.lax.stop_gradient(reward - baseline) * log_q)

    # ── Diagnostics (all stop-gradiented; no effect on training) ─────────────
    e = energies.mean()      # mean energy per spin
    h = -log_q.mean() / N    # entropy per spin (nats)
    f = e - T * h            # Helmholtz free energy per spin
    m = jnp.abs((token_ids * 2 - 1).astype(jnp.float32).mean(axis=-1)).mean()
    aux = {
        "e": e,
        "h": h,
        "f": f,
        "|m|": m,
        "reward_std": reward.std(),  # REINFORCE variance diagnostic
    }
    return loss, aux


# ---------------------------------------------------------------------------
# Single training step (JIT-compiled; does NOT include sampling)
# ---------------------------------------------------------------------------

@eqx.filter_jit
def vi_step(
    model,
    token_ids: jax.Array,
    opt_state,
    tx,
    T: float,
):
    """Compute VI loss + gradient and apply one optimiser update.

    Sampling is intentionally excluded so you can profile / replace it
    without re-compiling the gradient computation.
    """
    (loss, aux), grads = compute_vi_loss(model, token_ids, T)
    updates, opt_state = tx.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return loss, aux, model, opt_state


# ---------------------------------------------------------------------------
# Sampling helper
# ---------------------------------------------------------------------------

_SAMPLE_BATCH = 4  # fixed call-site batch; changing it triggers recompilation


def draw_samples(model, n: int, key: jax.Array) -> jax.Array:
    """Sample n configurations from the model in fixed-size batches.

    Returns a jnp int32 array of shape (n, L²).
    """
    all_tokens = []
    n_calls = -(-n // _SAMPLE_BATCH)  # ceiling division
    for _ in range(n_calls):
        key, subkey = jax.random.split(key)
        all_tokens.append(np.asarray(sample_batch(model, _SAMPLE_BATCH, subkey)))
    return jnp.asarray(np.concatenate(all_tokens)[:n])


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args():
    p = argparse.ArgumentParser(
        description="Variational inference fine-tuning of an Ising Generator.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--checkpoint", type=Path, default=None,
                   help="Warm-start from this .eqx file (strongly recommended).")
    p.add_argument("--output-checkpoint", type=Path, default=None,
                   help="Save final model to this path.")
    p.add_argument("--num-steps", type=int, default=200)
    p.add_argument("--batch-size", type=int, default=16,
                   help="Configurations sampled from q per gradient step.")
    p.add_argument("--learning-rate", type=float, default=1e-4)
    p.add_argument("--temperature", type=float, default=T_C,
                   help=f"Target Boltzmann temperature (default: T_c ≈ {T_C:.4f}).")
    p.add_argument("--log-every", type=int, default=1)
    p.add_argument("--seed", type=int, default=0)
    return p.parse_args()


def main():
    args = parse_args()
    T = args.temperature
    key = jax.random.PRNGKey(args.seed)

    # ── Model ─────────────────────────────────────────────────────────────────
    if args.checkpoint is not None:
        print(f"Loading checkpoint from {args.checkpoint} …")
        model = load_checkpoint(args.checkpoint)
    else:
        print("Initialising model from scratch.")
        print("  Tip: use --checkpoint to warm-start from a CE-pretrained model.")
        key, model_key = jax.random.split(key)
        model = Generator(config=gen_config, key=model_key)

    # ── Optimiser ─────────────────────────────────────────────────────────────
    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=args.learning_rate),
    )
    tx = optax.masked(tx, jax.tree.map(eqx.is_inexact_array, model))
    opt_state = tx.init(model)

    L = gen_config["lattice_size"]
    print(f"\nVI training | steps={args.num_steps} "
          f"batch={args.batch_size} T={T:.4f} lr={args.learning_rate} L={L}")
    print("  columns: e = ⟨E/N⟩  h = −⟨log q⟩/N  "
          "f = e−T·h (minimised)  |m| = mean |magnetisation|\n")

    # ── Training loop ─────────────────────────────────────────────────────────
    with tqdm(range(args.num_steps), unit="steps") as pbar:
        for step in pbar:
            # 1. Sample from current model (bottleneck on CPU)
            key, sample_key = jax.random.split(key)
            token_ids = draw_samples(model, args.batch_size, sample_key)

            # 2. VI gradient step (JIT-compiled teacher-forced forward pass)
            loss, aux, model, opt_state = vi_step(
                model, token_ids, opt_state, tx, T
            )

            if step % args.log_every == 0:
                pbar.set_postfix(
                    e=f"{float(aux['e']):.4f}",
                    h=f"{float(aux['h']):.4f}",
                    f=f"{float(aux['f']):.4f}",
                    m=f"{float(aux['|m|']):.3f}",
                    Rstd=f"{float(aux['reward_std']):.3f}",
                )

    # ── Save ──────────────────────────────────────────────────────────────────
    if args.output_checkpoint is not None:
        args.output_checkpoint.parent.mkdir(parents=True, exist_ok=True)
        eqx.tree_serialise_leaves(args.output_checkpoint, model)
        print(f"\nSaved checkpoint → {args.output_checkpoint}")


if __name__ == "__main__":
    main()
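A quick sanity check on the Hamiltonian above (a hypothetical snippet, not part of the upload): with periodic boundaries each spin contributes exactly two counted bonds (right and down), so a fully aligned configuration must come out at E/N = −2J:

    import jax.numpy as jnp

    from model import gen_config
    from vi_train import ising_energy_per_spin

    N = gen_config["lattice_size"] ** 2
    all_up = jnp.ones(N, dtype=jnp.int32)   # every token = 1, i.e. s = +1 everywhere
    print(ising_energy_per_spin(all_up))    # -2.0: two satisfied bonds per spin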