# Author: bertran-yorro — commit 2c223e2 ("Fix: app.py", verified)
import shlex
import subprocess
import sys
from pathlib import Path
import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Generated artefacts (checkpoints and sample arrays) are written under OUT.
# The spin datasets are relative paths, resolved against the working directory.
OUT = Path("outputs")
CE_CKPT = OUT.joinpath("ce_checkpoint.eqx")
VI_CKPT = OUT.joinpath("vi_checkpoint.eqx")
CE_SMPL = OUT.joinpath("samples_ce.npy")
VI_SMPL = OUT.joinpath("samples_vi.npy")
TRAIN_DATA = Path("spins.npy")  # handed to train.py as --data
TEST_DATA = Path("spins_test.npy")  # handed to eval.py as --test-data
# ---------------------------------------------------------------------------
# Subprocess helpers
# ---------------------------------------------------------------------------
def _stream_into(cmd: list[str], log_lines: list[str]):
    """Run *cmd* and stream its merged stdout/stderr into *log_lines*.

    Lines appended to *log_lines*, in order:
    - a ``$ <cmd>`` header;
    - every output line, with trailing whitespace stripped;
    - an ``[exit <returncode>]`` trailer.

    After each output line, and once more after the trailer, the generator
    yields the last 300 entries of *log_lines* joined with newlines. Callers
    can forward these as live log updates.

    stderr is merged into stdout so tracebacks are always visible.
    Undecodable bytes are replaced instead of aborting the stream.

    If the consumer stops iterating early (the generator is closed, e.g.
    because the UI request was cancelled), the child process is killed so it
    does not keep running unobserved.
    """
    import os  # local import: only needed to build the child environment

    log_lines.append("$ " + " ".join(shlex.quote(p) for p in cmd))
    # The commands launched from this file are Python scripts. With stdout
    # attached to a pipe they would block-buffer their output, which defeats
    # line-by-line streaming. Also pin their output encoding to match ours.
    env = {**os.environ, "PYTHONUNBUFFERED": "1", "PYTHONIOENCODING": "utf-8"}
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        encoding="utf-8",
        errors="replace",
        bufsize=1,
        env=env,
    ) as proc:
        assert proc.stdout is not None
        try:
            for line in proc.stdout:
                log_lines.append(line.rstrip())
                yield "\n".join(log_lines[-300:])
        except BaseException:
            # Early close (GeneratorExit) or any other error: don't leak the
            # child. Popen.__exit__ then closes the pipe and reaps it.
            proc.kill()
            raise
        rc = proc.wait()
    log_lines.append(f"[exit {rc}]")
    yield "\n".join(log_lines[-300:])
# ---------------------------------------------------------------------------
# Sample grid figure
# ---------------------------------------------------------------------------
def _samples_figure(path: Path, title: str, n: int = 16) -> plt.Figure | None:
    """Plot up to *n* sampled spin grids loaded from the ``.npy`` file at *path*.

    Each panel is drawn in grayscale on a fixed [-1, 1] value range and is
    titled with the grid's mean value ``m``. Panels are laid out at most
    8 per row.

    Returns:
        The matplotlib figure, or ``None`` if *path* does not exist or
        contains no grids to plot.
    """
    if not path.exists():
        return None
    grids = np.load(path).astype(np.float32)[:n]  # (N, L, L), values ±1
    if len(grids) == 0:
        # Nothing to draw. Returning here also avoids the division by zero
        # that cols == 0 would cause in the layout computation below.
        return None
    cols = min(8, len(grids))
    rows = (len(grids) + cols - 1) // cols  # ceiling division
    fig, axes = plt.subplots(
        rows, cols, figsize=(cols * 1.4, rows * 1.4), squeeze=False
    )
    mags = grids.mean(axis=(1, 2))
    for i, ax in enumerate(axes.ravel()):
        if i < len(grids):
            ax.imshow(grids[i], cmap="gray", vmin=-1, vmax=1, interpolation="nearest")
            ax.set_title(f"m={mags[i]:.2f}", fontsize=6)
        ax.axis("off")  # also hides the unused trailing panels
    fig.suptitle(title, fontsize=9)
    fig.tight_layout()
    # Remove the figure from pyplot's global registry. The caller only needs
    # the Figure object, and unclosed figures would pile up across eval runs.
    plt.close(fig)
    return fig
# ---------------------------------------------------------------------------
# Tab 1 – Cross-entropy training
# ---------------------------------------------------------------------------
def run_ce(mode, epochs, batch_size, lr, max_steps):
    """Run ``train.py`` (cross-entropy training) and stream its log.

    Yields ``(log_text, checkpoint_path_or_None)`` pairs. The checkpoint path
    is only reported in the final yield, and only if the run exited with
    status 0. A stale checkpoint from an earlier run is therefore never
    presented as the result of a failed one.

    Args:
        mode: ``"Smoke"`` for a 1-epoch, 5-step sanity run. Any other value
            runs a full training using *epochs* and *max_steps*.
        epochs: number of epochs for a full run (cast to ``int``).
        batch_size: training batch size (cast to ``int``).
        lr: learning rate, forwarded as ``str(lr)``.
        max_steps: optional cap on training steps for a full run. ``0`` or
            an empty field (``None``) means no cap.
    """
    OUT.mkdir(parents=True, exist_ok=True)
    cmd = [
        sys.executable, "train.py",
        "--data", str(TRAIN_DATA),
        "--batch-size", str(int(batch_size)),
        "--learning-rate", str(lr),
        "--output-checkpoint", str(CE_CKPT),
    ]
    if mode == "Smoke":
        cmd += ["--epochs", "1", "--max-train-steps", "5", "--max-eval-batches", "2"]
    else:
        cmd += ["--epochs", str(int(epochs))]
        # An empty Gradio Number field arrives as None; treat it like 0 (no cap).
        if max_steps and int(max_steps) > 0:
            cmd += ["--max-train-steps", str(int(max_steps))]
    log_lines: list[str] = []
    for log in _stream_into(cmd, log_lines):
        yield log, None
    # _stream_into always ends the log with "[exit <returncode>]".
    succeeded = log_lines[-1] == "[exit 0]"
    ckpt = str(CE_CKPT) if succeeded and CE_CKPT.exists() else None
    yield "\n".join(log_lines[-300:]), ckpt
# ---------------------------------------------------------------------------
# Tab 2 – Variational inference fine-tuning
# ---------------------------------------------------------------------------
def run_vi(mode, steps, batch_size, lr, warm_start):
    """Run ``vi_train.py`` (variational fine-tuning) and stream its log.

    Yields ``(log_text, checkpoint_path_or_None)`` pairs. The VI checkpoint
    path is only reported in the final yield, and only if the run exited with
    status 0. A stale checkpoint from an earlier run is therefore never
    presented as the result of a failed one.

    Args:
        mode: ``"Smoke"`` for a 3-step sanity run that logs every step. Any
            other value runs *steps* steps.
        steps: value for ``--num-steps`` in a full run (cast to ``int``).
        batch_size: batch size (cast to ``int``).
        lr: learning rate, forwarded as ``str(lr)``.
        warm_start: if truthy, pass the CE checkpoint as ``--checkpoint``.
            The run is refused with a warning when that checkpoint is missing.
    """
    OUT.mkdir(parents=True, exist_ok=True)
    if warm_start and not CE_CKPT.exists():
        yield "⚠ CE checkpoint not found. Run CE training first, or uncheck warm-start.", None
        return
    cmd = [
        sys.executable, "vi_train.py",
        "--batch-size", str(int(batch_size)),
        "--learning-rate", str(lr),
        "--output-checkpoint", str(VI_CKPT),
    ]
    if warm_start:
        # Existence of CE_CKPT was already verified above.
        cmd += ["--checkpoint", str(CE_CKPT)]
    if mode == "Smoke":
        cmd += ["--num-steps", "3", "--log-every", "1"]
    else:
        cmd += ["--num-steps", str(int(steps))]
    log_lines: list[str] = []
    for log in _stream_into(cmd, log_lines):
        yield log, None
    # _stream_into always ends the log with "[exit <returncode>]".
    succeeded = log_lines[-1] == "[exit 0]"
    ckpt = str(VI_CKPT) if succeeded and VI_CKPT.exists() else None
    yield "\n".join(log_lines[-300:]), ckpt
# ---------------------------------------------------------------------------
# Tab 3 – Sample & Eval
# ---------------------------------------------------------------------------
def run_eval(which, num_samples, seed):
    """Sample from the selected model(s), evaluate the samples, and plot them.

    For each selected model, in CE-then-VI order:
    - ``sample.py`` writes ``num_samples`` configurations to that model's
      samples file;
    - ``eval.py`` is then run on that file with ``TEST_DATA`` as
      ``--test-data``.

    Yields ``(log_text, ce_figure, vi_figure)`` tuples; both figures stay
    ``None`` until the final yield.

    Only samples produced by a successful ``sample.py`` run in *this* call
    are evaluated and plotted. A samples file left over from an earlier run
    is never mistaken for fresh output.

    Args:
        which: ``"CE"``, ``"VI"`` or ``"Both"``.
        num_samples: number of configurations to generate and evaluate
            (cast to ``int`` for the scripts).
        seed: seed forwarded to both scripts (cast to ``int``).
    """
    OUT.mkdir(parents=True, exist_ok=True)
    log_lines: list[str] = []

    def current_log():
        return "\n".join(log_lines[-300:])

    # (label, checkpoint, samples file) for every model the user selected.
    selected = [
        (label, ckpt, smpl)
        for label, ckpt, smpl in (("CE", CE_CKPT, CE_SMPL), ("VI", VI_CKPT, VI_SMPL))
        if which in (label, "Both")
    ]
    for label, ckpt, _ in selected:
        if not ckpt.exists():
            yield f"⚠ {label} checkpoint not found. Run {label} training first.", None, None
            return

    # ── Generate samples ───────────────────────────────────────────────────
    fresh: set[str] = set()  # labels whose sample.py exited 0 in this call
    for label, ckpt, smpl in selected:
        log_lines.append(f"\n── Generating {num_samples} {label} samples ──")
        yield current_log(), None, None
        cmd = [
            sys.executable, "sample.py",
            "--checkpoint", str(ckpt),
            "--num-samples", str(int(num_samples)),
            "--output", str(smpl),
            "--seed", str(int(seed)),
        ]
        for log in _stream_into(cmd, log_lines):
            yield log, None, None
        # _stream_into always ends the log with "[exit <returncode>]".
        if log_lines[-1] == "[exit 0]" and smpl.exists():
            fresh.add(label)

    # ── Run eval ──────────────────────────────────────────────────────────
    for label, ckpt, smpl in selected:
        if label not in fresh:
            log_lines.append(
                f"⚠ No fresh {label} samples at {smpl} — sample generation may have failed."
            )
            yield current_log(), None, None
            continue
        log_lines.append(f"\n── Evaluating {label} model ──")
        yield current_log(), None, None
        cmd = [
            sys.executable, "eval.py",
            "--checkpoint", str(ckpt),
            "--test-data", str(TEST_DATA),
            "--num-samples", str(int(num_samples)),
            "--samples-file", str(smpl),
            "--seed", str(int(seed)),
        ]
        for log in _stream_into(cmd, log_lines):
            yield log, None, None

    # ── Build figures ──────────────────────────────────────────────────────
    ce_fig = _samples_figure(CE_SMPL, "CE samples") if "CE" in fresh else None
    vi_fig = _samples_figure(VI_SMPL, "VI samples") if "VI" in fresh else None
    yield current_log(), ce_fig, vi_fig
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Gradio applies the queue's default concurrency limit per event listener.
# The three job buttons therefore share one `concurrency_id`; without it, a CE
# run, a VI run and an eval could execute at the same time while reading and
# writing the same checkpoint and sample files under `outputs/`.
_JOB_GROUP = "ising-jobs"

with gr.Blocks(title="Ising Transformer") as demo:
    gr.Markdown(
        "# 2D Ising Transformer\n"
        "Autoregressive transformer trained on the 2D Ising model at the critical "
        "temperature T_c ≈ 2.27. "
        "Run **CE training** first, optionally fine-tune with **Variational Inference**, "
        "then **Sample & Eval** to compare both against the held-out test set."
    )
    with gr.Tabs():
        # ── Tab 1: CE training ──────────────────────────────────────────────
        with gr.Tab("Cross-Entropy Training"):
            gr.Markdown(
                "Trains the model to maximise `log q(s)` on the training spin "
                "configurations (teacher forcing, causal attention). "
                "A *Smoke* run does 5 steps to verify everything compiles."
            )
            with gr.Row():
                ce_mode = gr.Radio(["Smoke", "Full"], value="Smoke", label="Mode")
                ce_epoch = gr.Number(value=10, precision=0, minimum=1, label="Epochs")
                ce_bs = gr.Number(value=32, precision=0, minimum=1, label="Batch size")
            with gr.Row():
                ce_lr = gr.Number(value=1e-4, label="Learning rate")
                ce_steps = gr.Number(
                    value=0, precision=0, minimum=0, label="Max steps (0 = no cap)"
                )
            ce_run = gr.Button("Run CE Training", variant="primary")
            ce_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
            ce_ckpt = gr.File(label="Checkpoint")
            ce_run.click(
                run_ce,
                inputs=[ce_mode, ce_epoch, ce_bs, ce_lr, ce_steps],
                outputs=[ce_logs, ce_ckpt],
                concurrency_id=_JOB_GROUP,
            )
        # ── Tab 2: VI fine-tuning ─────────────────────────────────────────
        with gr.Tab("Variational Inference Fine-tuning"):
            gr.Markdown(
                "Minimises the variational free energy `F = ⟨E(s)⟩ − T·H[q]` using "
                "the REINFORCE gradient estimator. Warm-starting from the CE "
                "checkpoint is strongly recommended. "
                "A *Smoke* run does 3 steps."
            )
            with gr.Row():
                vi_mode = gr.Radio(["Smoke", "Full"], value="Smoke", label="Mode")
                vi_steps = gr.Number(value=200, precision=0, minimum=1, label="Steps")
                vi_bs = gr.Number(value=16, precision=0, minimum=1, label="Batch size")
            with gr.Row():
                vi_lr = gr.Number(value=1e-4, label="Learning rate")
                vi_warm = gr.Checkbox(value=True, label="Warm-start from CE checkpoint")
            vi_run = gr.Button("Run VI Fine-tuning", variant="primary")
            vi_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
            vi_ckpt = gr.File(label="Checkpoint")
            vi_run.click(
                run_vi,
                inputs=[vi_mode, vi_steps, vi_bs, vi_lr, vi_warm],
                outputs=[vi_logs, vi_ckpt],
                concurrency_id=_JOB_GROUP,
            )
        # ── Tab 3: Sample & Eval ──────────────────────────────────────────
        with gr.Tab("Sample & Eval"):
            gr.Markdown(
                "Generates spin configurations from the selected model(s), then runs "
                "the physical-observable evaluation against `spins_test.npy` "
                "(a held-out set, never seen during training).\n\n"
                "Features compared: magnetisation, energy, two-point correlations, "
                "cluster statistics. Distance reported as **Mahalanobis D** in the "
                "decorrelated feature space."
            )
            with gr.Row():
                ev_which = gr.Radio(
                    ["CE", "VI", "Both"], value="Both", label="Model(s) to evaluate"
                )
                ev_n = gr.Number(value=64, precision=0, minimum=4, label="Num samples")
                ev_seed = gr.Number(value=0, precision=0, label="Seed")
            ev_run = gr.Button("Run Sample & Eval", variant="primary")
            ev_logs = gr.Textbox(label="Logs", lines=20, max_lines=30)
            with gr.Row():
                ev_ce_img = gr.Plot(label="CE samples")
                ev_vi_img = gr.Plot(label="VI samples")
            ev_run.click(
                run_eval,
                inputs=[ev_which, ev_n, ev_seed],
                outputs=[ev_logs, ev_ce_img, ev_vi_img],
                concurrency_id=_JOB_GROUP,
            )
if __name__ == "__main__":
    # Enable the request queue with at most one concurrent run per event
    # listener (Gradio's `default_concurrency_limit`), then serve the app
    # with the Soft theme. `queue()` returns the same Blocks object.
    queued_app = demo.queue(default_concurrency_limit=1)
    queued_app.launch(theme=gr.themes.Soft())