Spaces:
Sleeping
Sleeping
| import shlex | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
| import gradio as gr | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| # --------------------------------------------------------------------------- | |
| # Paths | |
| # --------------------------------------------------------------------------- | |
# Every generated artefact (checkpoints, sample arrays) lives under OUT.
OUT = Path("outputs")

# Checkpoints passed to the training scripts as --output-checkpoint.
CE_CKPT = OUT.joinpath("ce_checkpoint.eqx")
VI_CKPT = OUT.joinpath("vi_checkpoint.eqx")

# Sample arrays passed to sample.py as --output, one per checkpoint.
CE_SMPL = OUT.joinpath("samples_ce.npy")
VI_SMPL = OUT.joinpath("samples_vi.npy")

# Input datasets; relative paths, so they resolve against the working directory.
TRAIN_DATA = Path("spins.npy")
TEST_DATA = Path("spins_test.npy")
| # --------------------------------------------------------------------------- | |
| # Subprocess helpers | |
| # --------------------------------------------------------------------------- | |
def _stream_into(cmd: list[str], log_lines: list[str], tail: int = 300):
    """Run ``cmd`` and stream its output into ``log_lines``.

    The shell-quoted command line is appended first, then every output line
    of the child (stderr is merged into stdout so tracebacks are always
    visible), and finally an ``[exit <code>]`` marker.

    Args:
        cmd: Argument vector to execute; no shell is involved.
        log_lines: List that is extended in place, so the caller still holds
            the complete log once the generator is exhausted.
        tail: How many of the most recent log lines each yielded string
            contains. Must be positive.

    Yields:
        The last ``tail`` entries of ``log_lines`` joined with newlines, once
        per output line and once more after the process has exited.
    """
    log_lines.append("$ " + " ".join(shlex.quote(p) for p in cmd))
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        # Undecodable bytes in the child's output must not abort streaming.
        errors="replace",
        bufsize=1,
    )
    assert proc.stdout is not None  # guaranteed by stdout=PIPE
    try:
        for line in proc.stdout:
            log_lines.append(line.rstrip())
            yield "\n".join(log_lines[-tail:])
        rc = proc.wait()
    except BaseException:
        # Includes GeneratorExit: the consumer stopped iterating early, so
        # do not leave the child running in the background.
        proc.kill()
        raise
    finally:
        # Always release the pipe and reap the child.
        proc.stdout.close()
        proc.wait()
    log_lines.append(f"[exit {rc}]")
    yield "\n".join(log_lines[-tail:])
| # --------------------------------------------------------------------------- | |
| # Sample grid figure | |
| # --------------------------------------------------------------------------- | |
def _samples_figure(path: Path, title: str, n: int = 16) -> plt.Figure | None:
    """Render the first ``n`` arrays stored at ``path`` as a grid of images.

    Args:
        path: ``.npy`` file to load.
        title: Figure-level title.
        n: Maximum number of configurations to draw.

    Returns:
        A Matplotlib figure with up to 8 panels per row, each titled with the
        mean value of its array, or ``None`` when ``path`` does not exist or
        holds no samples.
    """
    if not path.exists():
        return None
    # Expected layout: (N, L, L) with values ±1.
    grids = np.load(path).astype(np.float32)[:n]
    if len(grids) == 0:
        # Nothing to draw; also avoids dividing by zero columns below.
        return None

    cols = min(8, len(grids))
    rows = (len(grids) + cols - 1) // cols  # ceiling division

    # Build the figure directly rather than through pyplot: pyplot keeps a
    # reference to every figure it creates until it is explicitly closed,
    # which accumulates across requests in a long-running app.
    fig = plt.Figure(figsize=(cols * 1.4, rows * 1.4))
    axes = fig.subplots(rows, cols, squeeze=False).ravel()

    mags = grids.mean(axis=(1, 2))
    for ax, grid, mag in zip(axes, grids, mags):
        ax.imshow(grid, cmap="gray", vmin=-1, vmax=1, interpolation="nearest")
        ax.set_title(f"m={mag:.2f}", fontsize=6)
    for ax in axes:
        # Unused trailing panels are blanked as well.
        ax.axis("off")

    fig.suptitle(title, fontsize=9)
    fig.tight_layout()
    return fig
| # --------------------------------------------------------------------------- | |
| # Tab 1 — Cross-entropy training | |
| # --------------------------------------------------------------------------- | |
def run_ce(mode, epochs, batch_size, lr, max_steps):
    """Launch ``train.py`` as a subprocess and stream its log to the UI.

    Args:
        mode: ``"Smoke"`` runs a fixed 1-epoch / 5-step / 2-eval-batch check;
            any other value runs a full training with the settings below.
        epochs: Value passed as ``--epochs`` (full mode only).
        batch_size: Value passed as ``--batch-size``.
        lr: Value passed as ``--learning-rate``.
        max_steps: Passed as ``--max-train-steps`` in full mode when
            positive; otherwise the flag is omitted.

    Yields:
        ``(log_text, checkpoint)`` tuples. ``checkpoint`` is ``None`` while
        the process runs; the final tuple carries the path of ``CE_CKPT`` if
        that file exists afterwards.
    """
    # A cleared numeric field is expected to arrive as None; report it
    # instead of failing inside int() before any log output is shown.
    required = {"Batch size": batch_size, "Learning rate": lr}
    if mode != "Smoke":
        required["Epochs"] = epochs
        required["Max steps"] = max_steps
    missing = [label for label, value in required.items() if value is None]
    if missing:
        yield f"⚠ Missing value for: {', '.join(missing)}.", None
        return

    OUT.mkdir(parents=True, exist_ok=True)
    cmd = [
        sys.executable, "train.py",
        "--data", str(TRAIN_DATA),
        "--batch-size", str(int(batch_size)),
        "--learning-rate", str(lr),
        "--output-checkpoint", str(CE_CKPT),
    ]
    if mode == "Smoke":
        cmd += ["--epochs", "1", "--max-train-steps", "5", "--max-eval-batches", "2"]
    else:
        cmd += ["--epochs", str(int(epochs))]
        if int(max_steps) > 0:
            cmd += ["--max-train-steps", str(int(max_steps))]

    log_lines: list[str] = []
    for log in _stream_into(cmd, log_lines):
        yield log, None

    # Only offer the checkpoint for download once the process has finished.
    ckpt = str(CE_CKPT) if CE_CKPT.exists() else None
    yield "\n".join(log_lines[-300:]), ckpt
| # --------------------------------------------------------------------------- | |
| # Tab 2 — Variational inference fine-tuning | |
| # --------------------------------------------------------------------------- | |
def run_vi(mode, steps, batch_size, lr, warm_start):
    """Launch ``vi_train.py`` as a subprocess and stream its log to the UI.

    Args:
        mode: ``"Smoke"`` runs 3 steps with ``--log-every 1``; any other
            value runs ``steps`` steps.
        steps: Value passed as ``--num-steps`` (full mode only).
        batch_size: Value passed as ``--batch-size``.
        lr: Value passed as ``--learning-rate``.
        warm_start: When true, ``CE_CKPT`` is passed as ``--checkpoint``;
            the run is refused if that file does not exist.

    Yields:
        ``(log_text, checkpoint)`` tuples. ``checkpoint`` is ``None`` while
        the process runs; the final tuple carries the path of ``VI_CKPT`` if
        that file exists afterwards.
    """
    # A cleared numeric field is expected to arrive as None; report it
    # instead of failing inside int() before any log output is shown.
    required = {"Batch size": batch_size, "Learning rate": lr}
    if mode != "Smoke":
        required["Steps"] = steps
    missing = [label for label, value in required.items() if value is None]
    if missing:
        yield f"⚠ Missing value for: {', '.join(missing)}.", None
        return

    OUT.mkdir(parents=True, exist_ok=True)
    if warm_start and not CE_CKPT.exists():
        yield "⚠ CE checkpoint not found. Run CE training first, or uncheck warm-start.", None
        return

    cmd = [
        sys.executable, "vi_train.py",
        "--batch-size", str(int(batch_size)),
        "--learning-rate", str(lr),
        "--output-checkpoint", str(VI_CKPT),
    ]
    if warm_start:
        # The guard above already established that the checkpoint exists.
        cmd += ["--checkpoint", str(CE_CKPT)]
    if mode == "Smoke":
        cmd += ["--num-steps", "3", "--log-every", "1"]
    else:
        cmd += ["--num-steps", str(int(steps))]

    log_lines: list[str] = []
    for log in _stream_into(cmd, log_lines):
        yield log, None

    # Only offer the checkpoint for download once the process has finished.
    ckpt = str(VI_CKPT) if VI_CKPT.exists() else None
    yield "\n".join(log_lines[-300:]), ckpt
| # --------------------------------------------------------------------------- | |
| # Tab 3 — Sample & Eval | |
| # --------------------------------------------------------------------------- | |
def run_eval(which, num_samples, seed):
    """Generate samples with ``sample.py``, evaluate them with ``eval.py``, and plot them.

    Args:
        which: ``"CE"``, ``"VI"`` or ``"Both"`` — the model(s) to process.
        num_samples: Value passed as ``--num-samples`` to both scripts.
        seed: Value passed as ``--seed`` to both scripts.

    Yields:
        ``(log_text, ce_figure, vi_figure)`` tuples. Both figures are
        ``None`` until the final tuple, where each selected model's figure is
        filled in if its samples file exists.
    """
    OUT.mkdir(parents=True, exist_ok=True)
    log_lines: list[str] = []

    def current_log():
        return "\n".join(log_lines[-300:])

    # A cleared numeric field is expected to arrive as None; report it
    # instead of failing inside int().
    missing = [
        label
        for label, value in (("Num samples", num_samples), ("Seed", seed))
        if value is None
    ]
    if missing:
        yield f"⚠ Missing value for: {', '.join(missing)}.", None, None
        return
    n_samples = int(num_samples)
    seed_value = int(seed)

    # (label, checkpoint, samples file) for every selected model, CE first.
    targets = []
    if which in ("CE", "Both"):
        targets.append(("CE", CE_CKPT, CE_SMPL))
    if which in ("VI", "Both"):
        targets.append(("VI", VI_CKPT, VI_SMPL))

    for label, ckpt, _ in targets:
        if not ckpt.exists():
            yield f"⚠ {label} checkpoint not found. Run {label} training first.", None, None
            return

    # ── Generate samples ───────────────────────────────────────────────────
    for label, ckpt, smpl in targets:
        # Remove samples left over from an earlier run: otherwise a failed
        # generation would go unnoticed and stale samples would be evaluated
        # and plotted as if they were fresh.
        smpl.unlink(missing_ok=True)
        log_lines.append(f"\n── Generating {n_samples} {label} samples ──")
        yield current_log(), None, None
        cmd = [
            sys.executable, "sample.py",
            "--checkpoint", str(ckpt),
            "--num-samples", str(n_samples),
            "--output", str(smpl),
            "--seed", str(seed_value),
        ]
        for log in _stream_into(cmd, log_lines):
            yield log, None, None

    # ── Run eval ───────────────────────────────────────────────────────────
    for label, ckpt, smpl in targets:
        if not smpl.exists():
            log_lines.append(f"⚠ {smpl} not found — sample generation may have failed.")
            yield current_log(), None, None
            continue
        log_lines.append(f"\n── Evaluating {label} model ──")
        yield current_log(), None, None
        cmd = [
            sys.executable, "eval.py",
            "--checkpoint", str(ckpt),
            "--test-data", str(TEST_DATA),
            "--num-samples", str(n_samples),
            "--samples-file", str(smpl),
            "--seed", str(seed_value),
        ]
        for log in _stream_into(cmd, log_lines):
            yield log, None, None

    # ── Build figures ──────────────────────────────────────────────────────
    figures = {"CE": None, "VI": None}
    for label, _, smpl in targets:
        figures[label] = _samples_figure(smpl, f"{label} samples")
    yield current_log(), figures["CE"], figures["VI"]
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
# Three-tab UI: each tab collects parameters, runs one of the generator
# handlers above, and streams the handler's yielded log text into a textbox.
with gr.Blocks(title="Ising Transformer") as demo:
    gr.Markdown(
        "# 2D Ising Transformer\n"
        "Autoregressive transformer trained on the 2D Ising model at the critical "
        "temperature T_c ≈ 2.27. "
        "Run **CE training** first, optionally fine-tune with **Variational Inference**, "
        "then **Sample & Eval** to compare both against the held-out test set."
    )
    with gr.Tabs():
        # ── Tab 1: CE training ──────────────────────────────────────────────
        with gr.Tab("Cross-Entropy Training"):
            gr.Markdown(
                "Trains the model to maximise `log q(s)` on the training spin "
                "configurations (teacher forcing, causal attention). "
                "A *Smoke* run does 5 steps to verify everything compiles."
            )
            with gr.Row():
                ce_mode = gr.Radio(["Smoke", "Full"], value="Smoke", label="Mode")
                ce_epoch = gr.Number(value=10, precision=0, minimum=1, label="Epochs")
                ce_bs = gr.Number(value=32, precision=0, minimum=1, label="Batch size")
            with gr.Row():
                ce_lr = gr.Number(value=1e-4, label="Learning rate")
                ce_steps = gr.Number(value=0, precision=0, minimum=0, label="Max steps (0 = no cap)")
            ce_run = gr.Button("Run CE Training", variant="primary")
            ce_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
            ce_ckpt = gr.File(label="Checkpoint")
            # Input order must match run_ce's positional parameters.
            ce_run.click(
                run_ce,
                inputs=[ce_mode, ce_epoch, ce_bs, ce_lr, ce_steps],
                outputs=[ce_logs, ce_ckpt],
            )
        # ── Tab 2: VI fine-tuning ───────────────────────────────────────────
        with gr.Tab("Variational Inference Fine-tuning"):
            gr.Markdown(
                "Minimises the variational free energy `F = ⟨E(s)⟩ − T·H[q]` using "
                "the REINFORCE gradient estimator. Warm-starting from the CE "
                "checkpoint is strongly recommended. "
                "A *Smoke* run does 3 steps."
            )
            with gr.Row():
                vi_mode = gr.Radio(["Smoke", "Full"], value="Smoke", label="Mode")
                vi_steps = gr.Number(value=200, precision=0, minimum=1, label="Steps")
                vi_bs = gr.Number(value=16, precision=0, minimum=1, label="Batch size")
            with gr.Row():
                vi_lr = gr.Number(value=1e-4, label="Learning rate")
                vi_warm = gr.Checkbox(value=True, label="Warm-start from CE checkpoint")
            vi_run = gr.Button("Run VI Fine-tuning", variant="primary")
            vi_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
            vi_ckpt = gr.File(label="Checkpoint")
            # Input order must match run_vi's positional parameters.
            vi_run.click(
                run_vi,
                inputs=[vi_mode, vi_steps, vi_bs, vi_lr, vi_warm],
                outputs=[vi_logs, vi_ckpt],
            )
        # ── Tab 3: Sample & Eval ────────────────────────────────────────────
        with gr.Tab("Sample & Eval"):
            gr.Markdown(
                "Generates spin configurations from the selected model(s), then runs "
                "the physical-observable evaluation against `spins_test.npy` "
                "(a held-out set, never seen during training).\n\n"
                "Features compared: magnetisation, energy, two-point correlations, "
                "cluster statistics. Distance reported as **Mahalanobis D** in the "
                "decorrelated feature space."
            )
            with gr.Row():
                ev_which = gr.Radio(
                    ["CE", "VI", "Both"], value="Both", label="Model(s) to evaluate"
                )
                ev_n = gr.Number(value=64, precision=0, minimum=4, label="Num samples")
                ev_seed = gr.Number(value=0, precision=0, label="Seed")
            ev_run = gr.Button("Run Sample & Eval", variant="primary")
            ev_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
            with gr.Row():
                ev_ce_img = gr.Plot(label="CE samples")
                ev_vi_img = gr.Plot(label="VI samples")
            # Input order must match run_eval's positional parameters.
            ev_run.click(
                run_eval,
                inputs=[ev_which, ev_n, ev_seed],
                outputs=[ev_logs, ev_ce_img, ev_vi_img],
            )
# Script entry point: enable the event queue, then start the server.
if __name__ == "__main__":
    # default_concurrency_limit=1 is passed so that, by default, handlers
    # (which spawn training/eval subprocesses) do not run concurrently.
    queued_app = demo.queue(default_concurrency_limit=1)
    queued_app.launch(theme=gr.themes.Soft())