"""Gradio front-end for the 2D Ising transformer workflow.

Each tab launches one of the project scripts (``train.py``, ``vi_train.py``,
``sample.py``, ``eval.py``) as a subprocess and streams its output into a
log textbox.
"""

import shlex
import subprocess
import sys
from collections.abc import Iterator
from pathlib import Path

import gradio as gr
import matplotlib

matplotlib.use("Agg")  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
OUT = Path("outputs")
CE_CKPT = OUT / "ce_checkpoint.eqx"
VI_CKPT = OUT / "vi_checkpoint.eqx"
CE_SMPL = OUT / "samples_ce.npy"
VI_SMPL = OUT / "samples_vi.npy"
TRAIN_DATA = Path("spins.npy")
TEST_DATA = Path("spins_test.npy")

# Number of trailing log lines shown in the UI textboxes.
_LOG_TAIL = 300


# ---------------------------------------------------------------------------
# Subprocess helpers
# ---------------------------------------------------------------------------
def _tail(log_lines: list[str]) -> str:
    """Return the last ``_LOG_TAIL`` entries of *log_lines* joined by newlines."""
    return "\n".join(log_lines[-_LOG_TAIL:])


def _stream_into(cmd: list[str], log_lines: list[str]) -> Iterator[str]:
    """Run *cmd* and stream its output into *log_lines*.

    The shell-quoted command line is appended to *log_lines* first, then every
    output line of the child (stderr is merged into stdout so tracebacks are
    always visible), and finally an ``[exit <code>]`` marker.

    Args:
        cmd: Argument vector passed to ``subprocess.Popen`` (no shell).
        log_lines: List that is mutated in place; callers keep a reference
            to read the full log afterwards.

    Yields:
        The joined tail of *log_lines* after each output line, and once more
        after the exit marker has been appended.
    """
    log_lines.append("$ " + " ".join(shlex.quote(p) for p in cmd))
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,  # line-buffered so lines arrive as they are printed
    )
    assert proc.stdout is not None  # guaranteed by stdout=PIPE; narrows the type
    try:
        for line in proc.stdout:
            log_lines.append(line.rstrip())
            yield _tail(log_lines)
        rc = proc.wait()
    finally:
        # If the consumer closes this generator early, or an exception escapes
        # mid-stream, do not leave the child running or the pipe open.
        if proc.poll() is None:
            proc.kill()
            proc.wait()
        proc.stdout.close()
    log_lines.append(f"[exit {rc}]")
    yield _tail(log_lines)


# ---------------------------------------------------------------------------
# Sample grid figure
# ---------------------------------------------------------------------------
def _samples_figure(path: Path, title: str, n: int = 16) -> plt.Figure | None:
    """Render the first *n* sample grids stored in *path* as an image grid.

    Args:
        path: ``.npy`` file holding the samples.
        title: Figure super-title.
        n: Maximum number of grids to draw.

    Returns:
        A matplotlib figure with up to 8 columns, each panel titled with the
        mean of its grid, or ``None`` when *path* does not exist or holds no
        samples.
    """
    if not path.exists():
        return None
    grids = np.load(path).astype(np.float32)[:n]  # (N, L, L), values ±1
    if len(grids) == 0:
        # Nothing to draw; also avoids a division by zero in the row count.
        return None
    cols = min(8, len(grids))
    rows = (len(grids) + cols - 1) // cols  # ceiling division
    fig, axes = plt.subplots(
        rows, cols, figsize=(cols * 1.4, rows * 1.4), squeeze=False
    )
    mags = grids.mean(axis=(1, 2))
    for i, ax in enumerate(axes.ravel()):
        if i < len(grids):
            ax.imshow(grids[i], cmap="gray", vmin=-1, vmax=1, interpolation="nearest")
            ax.set_title(f"m={mags[i]:.2f}", fontsize=6)
        ax.axis("off")  # also blanks the unused trailing panels
    fig.suptitle(title, fontsize=9)
    fig.tight_layout()
    # Drop pyplot's global reference so figures do not accumulate across
    # requests; the Figure object itself remains usable by the caller.
    plt.close(fig)
    return fig


# ---------------------------------------------------------------------------
# Tab 1 – Cross-entropy training
# ---------------------------------------------------------------------------
def run_ce(mode, epochs, batch_size, lr, max_steps):
    """Run ``train.py`` and stream its log.

    Args:
        mode: ``"Smoke"`` for a fixed 1-epoch / 5-step / 2-eval-batch run;
            any other value uses *epochs* and *max_steps*.
        epochs: Number of epochs (ignored in smoke mode).
        batch_size: Training batch size.
        lr: Learning rate.
        max_steps: Cap on training steps; ``0`` or an empty field means no
            cap (ignored in smoke mode).

    Yields:
        ``(log_text, checkpoint_path)`` pairs. The checkpoint is ``None``
        while the process runs and, on the final yield, the path of
        ``CE_CKPT`` if that file exists.
    """
    OUT.mkdir(parents=True, exist_ok=True)
    cmd = [
        sys.executable, "train.py",
        "--data", str(TRAIN_DATA),
        "--batch-size", str(int(batch_size)),
        "--learning-rate", str(lr),
        "--output-checkpoint", str(CE_CKPT),
    ]
    if mode == "Smoke":
        cmd += ["--epochs", "1", "--max-train-steps", "5", "--max-eval-batches", "2"]
    else:
        cmd += ["--epochs", str(int(epochs))]
        # A cleared number field arrives as None; treat it like 0 (no cap).
        if max_steps and int(max_steps) > 0:
            cmd += ["--max-train-steps", str(int(max_steps))]

    log_lines: list[str] = []
    for log in _stream_into(cmd, log_lines):
        yield log, None

    ckpt = str(CE_CKPT) if CE_CKPT.exists() else None
    yield _tail(log_lines), ckpt


# ---------------------------------------------------------------------------
# Tab 2 – Variational inference fine-tuning
# ---------------------------------------------------------------------------
def run_vi(mode, steps, batch_size, lr, warm_start):
    """Run ``vi_train.py`` and stream its log.

    Args:
        mode: ``"Smoke"`` for a fixed 3-step run; any other value uses *steps*.
        steps: Number of optimisation steps (ignored in smoke mode).
        batch_size: Batch size.
        lr: Learning rate.
        warm_start: When true, pass ``CE_CKPT`` as ``--checkpoint``; a warning
            is yielded and nothing is run if that file is missing.

    Yields:
        ``(log_text, checkpoint_path)`` pairs. The checkpoint is ``None``
        while the process runs and, on the final yield, the path of
        ``VI_CKPT`` if that file exists.
    """
    OUT.mkdir(parents=True, exist_ok=True)

    if warm_start and not CE_CKPT.exists():
        yield "⚠ CE checkpoint not found. Run CE training first, or uncheck warm-start.", None
        return

    cmd = [
        sys.executable, "vi_train.py",
        "--batch-size", str(int(batch_size)),
        "--learning-rate", str(lr),
        "--output-checkpoint", str(VI_CKPT),
    ]
    if warm_start:  # existence of CE_CKPT was verified by the guard above
        cmd += ["--checkpoint", str(CE_CKPT)]
    if mode == "Smoke":
        cmd += ["--num-steps", "3", "--log-every", "1"]
    else:
        cmd += ["--num-steps", str(int(steps))]

    log_lines: list[str] = []
    for log in _stream_into(cmd, log_lines):
        yield log, None

    ckpt = str(VI_CKPT) if VI_CKPT.exists() else None
    yield _tail(log_lines), ckpt


# ---------------------------------------------------------------------------
# Tab 3 – Sample & Eval
# ---------------------------------------------------------------------------
def run_eval(which, num_samples, seed):
    """Generate samples with ``sample.py``, evaluate them with ``eval.py``, and plot them.

    Args:
        which: ``"CE"``, ``"VI"`` or ``"Both"``.
        num_samples: Number of samples to generate and evaluate.
        seed: Seed forwarded to both scripts.

    Yields:
        ``(log_text, ce_figure, vi_figure)`` triples. Both figures are
        ``None`` until the final yield.
    """
    OUT.mkdir(parents=True, exist_ok=True)
    log_lines: list[str] = []

    def current_log():
        return _tail(log_lines)

    run_ce_ = which in ("CE", "Both")
    run_vi_ = which in ("VI", "Both")

    if run_ce_ and not CE_CKPT.exists():
        yield "⚠ CE checkpoint not found. Run CE training first.", None, None
        return
    if run_vi_ and not VI_CKPT.exists():
        yield "⚠ VI checkpoint not found. Run VI training first.", None, None
        return

    # Models selected for this run, in a fixed CE-then-VI order.
    targets = [
        (ckpt, smpl, label)
        for ckpt, smpl, label, wanted in (
            (CE_CKPT, CE_SMPL, "CE", run_ce_),
            (VI_CKPT, VI_SMPL, "VI", run_vi_),
        )
        if wanted
    ]

    # ── Generate samples ───────────────────────────────────────────────────
    for ckpt, out_path, label in targets:
        log_lines.append(f"\n── Generating {num_samples} {label} samples ──")
        yield current_log(), None, None
        # Remove any file left from an earlier run so that a failed generation
        # is caught by the existence check below instead of silently
        # evaluating and plotting stale samples.
        out_path.unlink(missing_ok=True)
        cmd = [
            sys.executable, "sample.py",
            "--checkpoint", str(ckpt),
            "--num-samples", str(int(num_samples)),
            "--output", str(out_path),
            "--seed", str(int(seed)),
        ]
        for log in _stream_into(cmd, log_lines):
            yield log, None, None

    # ── Run eval ──────────────────────────────────────────────────────────
    for ckpt, smpl, label in targets:
        if not smpl.exists():
            log_lines.append(f"⚠ {smpl} not found — sample generation may have failed.")
            yield current_log(), None, None
            continue
        log_lines.append(f"\n── Evaluating {label} model ──")
        yield current_log(), None, None
        cmd = [
            sys.executable, "eval.py",
            "--checkpoint", str(ckpt),
            "--test-data", str(TEST_DATA),
            "--num-samples", str(int(num_samples)),
            "--samples-file", str(smpl),
            "--seed", str(int(seed)),
        ]
        for log in _stream_into(cmd, log_lines):
            yield log, None, None

    # ── Build figures ──────────────────────────────────────────────────────
    ce_fig = _samples_figure(CE_SMPL, "CE samples") if run_ce_ else None
    vi_fig = _samples_figure(VI_SMPL, "VI samples") if run_vi_ else None
    yield current_log(), ce_fig, vi_fig


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# NOTE(review): the original indentation was lost; which widgets sit inside
# each gr.Row() below is reconstructed from convention — confirm the layout.
with gr.Blocks(title="Ising Transformer") as demo:
    gr.Markdown(
        "# 2D Ising Transformer\n"
        "Autoregressive transformer trained on the 2D Ising model at the critical "
        "temperature T_c ≈ 2.27. "
        "Run **CE training** first, optionally fine-tune with **Variational Inference**, "
        "then **Sample & Eval** to compare both against the held-out test set."
    )

    with gr.Tabs():

        # ── Tab 1: CE training ──────────────────────────────────────────────
        with gr.Tab("Cross-Entropy Training"):
            gr.Markdown(
                "Trains the model to maximise `log q(s)` on the training spin "
                "configurations (teacher forcing, causal attention). "
                "A *Smoke* run does 5 steps to verify everything compiles."
            )
            with gr.Row():
                ce_mode = gr.Radio(["Smoke", "Full"], value="Smoke", label="Mode")
                ce_epoch = gr.Number(value=10, precision=0, minimum=1, label="Epochs")
                ce_bs = gr.Number(value=32, precision=0, minimum=1, label="Batch size")
            with gr.Row():
                ce_lr = gr.Number(value=1e-4, label="Learning rate")
                ce_steps = gr.Number(value=0, precision=0, minimum=0, label="Max steps (0 = no cap)")
            ce_run = gr.Button("Run CE Training", variant="primary")
            ce_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
            ce_ckpt = gr.File(label="Checkpoint")
            ce_run.click(
                run_ce,
                inputs=[ce_mode, ce_epoch, ce_bs, ce_lr, ce_steps],
                outputs=[ce_logs, ce_ckpt],
            )

        # ── Tab 2: VI fine-tuning ─────────────────────────────────────────
        with gr.Tab("Variational Inference Fine-tuning"):
            gr.Markdown(
                "Minimises the variational free energy `F = ⟨E(s)⟩ − T·H[q]` using "
                "the REINFORCE gradient estimator. Warm-starting from the CE "
                "checkpoint is strongly recommended. "
                "A *Smoke* run does 3 steps."
            )
            with gr.Row():
                vi_mode = gr.Radio(["Smoke", "Full"], value="Smoke", label="Mode")
                vi_steps = gr.Number(value=200, precision=0, minimum=1, label="Steps")
                vi_bs = gr.Number(value=16, precision=0, minimum=1, label="Batch size")
            with gr.Row():
                vi_lr = gr.Number(value=1e-4, label="Learning rate")
                vi_warm = gr.Checkbox(value=True, label="Warm-start from CE checkpoint")
            vi_run = gr.Button("Run VI Fine-tuning", variant="primary")
            vi_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
            vi_ckpt = gr.File(label="Checkpoint")
            vi_run.click(
                run_vi,
                inputs=[vi_mode, vi_steps, vi_bs, vi_lr, vi_warm],
                outputs=[vi_logs, vi_ckpt],
            )

        # ── Tab 3: Sample & Eval ──────────────────────────────────────────
        with gr.Tab("Sample & Eval"):
            gr.Markdown(
                "Generates spin configurations from the selected model(s), then runs "
                "the physical-observable evaluation against `spins_test.npy` "
                "(a held-out set, never seen during training).\n\n"
                "Features compared: magnetisation, energy, two-point correlations, "
                "cluster statistics. Distance reported as **Mahalanobis D** in the "
                "decorrelated feature space."
            )
            with gr.Row():
                ev_which = gr.Radio(
                    ["CE", "VI", "Both"], value="Both", label="Model(s) to evaluate"
                )
                ev_n = gr.Number(value=64, precision=0, minimum=4, label="Num samples")
                ev_seed = gr.Number(value=0, precision=0, label="Seed")
            ev_run = gr.Button("Run Sample & Eval", variant="primary")
            ev_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
            with gr.Row():
                ev_ce_img = gr.Plot(label="CE samples")
                ev_vi_img = gr.Plot(label="VI samples")
            ev_run.click(
                run_eval,
                inputs=[ev_which, ev_n, ev_seed],
                outputs=[ev_logs, ev_ce_img, ev_vi_img],
            )


if __name__ == "__main__":
    # A concurrency limit of 1 lets only one event handler run at a time.
    demo.queue(default_concurrency_limit=1).launch(theme=gr.themes.Soft())