bertran-yorro commited on
Commit
2c223e2
Β·
verified Β·
1 Parent(s): 0f90baf

Fix: app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -34
app.py CHANGED
@@ -26,12 +26,15 @@ TEST_DATA = Path("spins_test.npy")
26
  # Subprocess helpers
27
  # ---------------------------------------------------------------------------
28
 
29
- def _stream(command: list[str]):
30
- """Run a command and yield log lines in real time."""
31
- log = ["$ " + " ".join(shlex.quote(p) for p in command), ""]
32
- yield "\n".join(log)
 
 
 
33
  proc = subprocess.Popen(
34
- command,
35
  stdout=subprocess.PIPE,
36
  stderr=subprocess.STDOUT,
37
  text=True,
@@ -39,11 +42,11 @@ def _stream(command: list[str]):
39
  )
40
  assert proc.stdout is not None
41
  for line in proc.stdout:
42
- log.append(line.rstrip())
43
- yield "\n".join(log[-300:])
44
  rc = proc.wait()
45
- log += ["", f"β€” exited {rc} β€”"]
46
- yield "\n".join(log[-300:])
47
 
48
 
49
  # ---------------------------------------------------------------------------
@@ -88,10 +91,11 @@ def run_ce(mode, epochs, batch_size, lr, max_steps):
88
  cmd += ["--epochs", str(int(epochs))]
89
  if int(max_steps) > 0:
90
  cmd += ["--max-train-steps", str(int(max_steps))]
91
- for log in _stream(cmd):
 
92
  yield log, None
93
  ckpt = str(CE_CKPT) if CE_CKPT.exists() else None
94
- yield log, ckpt
95
 
96
 
97
  # ---------------------------------------------------------------------------
@@ -115,10 +119,11 @@ def run_vi(mode, steps, batch_size, lr, warm_start):
115
  cmd += ["--num-steps", "3", "--log-every", "1"]
116
  else:
117
  cmd += ["--num-steps", str(int(steps))]
118
- for log in _stream(cmd):
 
119
  yield log, None
120
  ckpt = str(VI_CKPT) if VI_CKPT.exists() else None
121
- yield log, ckpt
122
 
123
 
124
  # ---------------------------------------------------------------------------
@@ -127,23 +132,19 @@ def run_vi(mode, steps, batch_size, lr, warm_start):
127
 
128
  def run_eval(which, num_samples, seed):
129
  OUT.mkdir(parents=True, exist_ok=True)
130
- log_lines = []
131
 
132
- def emit(msg=""):
133
- log_lines.append(msg)
134
- return (
135
- "\n".join(log_lines[-200:]),
136
- None, None, # CE figure, VI figure
137
- )
138
 
139
  run_ce_ = which in ("CE", "Both")
140
  run_vi_ = which in ("VI", "Both")
141
 
142
  if run_ce_ and not CE_CKPT.exists():
143
- yield emit("⚠ CE checkpoint not found. Run CE training first.")
144
  return
145
  if run_vi_ and not VI_CKPT.exists():
146
- yield emit("⚠ VI checkpoint not found. Run VI training first.")
147
  return
148
 
149
  # ── Generate samples ───────────────────────────────────────────────────
@@ -154,7 +155,7 @@ def run_eval(which, num_samples, seed):
154
  if (label == "CE" and not run_ce_) or (label == "VI" and not run_vi_):
155
  continue
156
  log_lines.append(f"\n── Generating {num_samples} {label} samples ──")
157
- yield "\n".join(log_lines[-200:]), None, None
158
  cmd = [
159
  sys.executable, "sample.py",
160
  "--checkpoint", str(ckpt),
@@ -162,9 +163,8 @@ def run_eval(which, num_samples, seed):
162
  "--output", str(out_path),
163
  "--seed", str(int(seed)),
164
  ]
165
- for chunk in _stream(cmd):
166
- log_lines[-1:] = chunk.splitlines()[-10:]
167
- yield "\n".join(log_lines[-200:]), None, None
168
 
169
  # ── Run eval ──────────────────────────────────────────────────────────
170
  for ckpt, smpl, label in [
@@ -173,31 +173,34 @@ def run_eval(which, num_samples, seed):
173
  ]:
174
  if (label == "CE" and not run_ce_) or (label == "VI" and not run_vi_):
175
  continue
 
 
 
 
176
  log_lines.append(f"\n── Evaluating {label} model ──")
177
- yield "\n".join(log_lines[-200:]), None, None
178
  cmd = [
179
  sys.executable, "eval.py",
180
  "--checkpoint", str(ckpt),
181
  "--test-data", str(TEST_DATA),
182
  "--num-samples", str(int(num_samples)),
183
- "--samples-file",str(smpl),
184
  "--seed", str(int(seed)),
185
  ]
186
- for chunk in _stream(cmd):
187
- log_lines[-1:] = chunk.splitlines()[-20:]
188
- yield "\n".join(log_lines[-200:]), None, None
189
 
190
  # ── Build figures ──────────────────────────────────────────────────────
191
  ce_fig = _samples_figure(CE_SMPL, "CE samples") if run_ce_ else None
192
  vi_fig = _samples_figure(VI_SMPL, "VI samples") if run_vi_ else None
193
- yield "\n".join(log_lines[-200:]), ce_fig, vi_fig
194
 
195
 
196
  # ---------------------------------------------------------------------------
197
  # Gradio UI
198
  # ---------------------------------------------------------------------------
199
 
200
- with gr.Blocks(title="Ising Transformer", theme=gr.themes.Soft()) as demo:
201
  gr.Markdown(
202
  "# 2D Ising Transformer\n"
203
  "Autoregressive transformer trained on the 2D Ising model at the critical "
@@ -284,4 +287,4 @@ with gr.Blocks(title="Ising Transformer", theme=gr.themes.Soft()) as demo:
284
 
285
 
286
  if __name__ == "__main__":
287
- demo.queue(default_concurrency_limit=1).launch()
 
26
  # Subprocess helpers
27
  # ---------------------------------------------------------------------------
28
 
29
+ def _stream_into(cmd: list[str], log_lines: list[str]):
30
+ """Run cmd, append each stdout line to log_lines, yield log_lines after each line.
31
+
32
+ stderr is merged into stdout so tracebacks are always visible.
33
+ Yields the joined log after every line so callers can stream updates.
34
+ """
35
+ log_lines.append("$ " + " ".join(shlex.quote(p) for p in cmd))
36
  proc = subprocess.Popen(
37
+ cmd,
38
  stdout=subprocess.PIPE,
39
  stderr=subprocess.STDOUT,
40
  text=True,
 
42
  )
43
  assert proc.stdout is not None
44
  for line in proc.stdout:
45
+ log_lines.append(line.rstrip())
46
+ yield "\n".join(log_lines[-300:])
47
  rc = proc.wait()
48
+ log_lines.append(f"[exit {rc}]")
49
+ yield "\n".join(log_lines[-300:])
50
 
51
 
52
  # ---------------------------------------------------------------------------
 
91
  cmd += ["--epochs", str(int(epochs))]
92
  if int(max_steps) > 0:
93
  cmd += ["--max-train-steps", str(int(max_steps))]
94
+ log_lines: list[str] = []
95
+ for log in _stream_into(cmd, log_lines):
96
  yield log, None
97
  ckpt = str(CE_CKPT) if CE_CKPT.exists() else None
98
+ yield "\n".join(log_lines[-300:]), ckpt
99
 
100
 
101
  # ---------------------------------------------------------------------------
 
119
  cmd += ["--num-steps", "3", "--log-every", "1"]
120
  else:
121
  cmd += ["--num-steps", str(int(steps))]
122
+ log_lines: list[str] = []
123
+ for log in _stream_into(cmd, log_lines):
124
  yield log, None
125
  ckpt = str(VI_CKPT) if VI_CKPT.exists() else None
126
+ yield "\n".join(log_lines[-300:]), ckpt
127
 
128
 
129
  # ---------------------------------------------------------------------------
 
132
 
133
  def run_eval(which, num_samples, seed):
134
  OUT.mkdir(parents=True, exist_ok=True)
135
+ log_lines: list[str] = []
136
 
137
+ def current_log():
138
+ return "\n".join(log_lines[-300:])
 
 
 
 
139
 
140
  run_ce_ = which in ("CE", "Both")
141
  run_vi_ = which in ("VI", "Both")
142
 
143
  if run_ce_ and not CE_CKPT.exists():
144
+ yield "⚠ CE checkpoint not found. Run CE training first.", None, None
145
  return
146
  if run_vi_ and not VI_CKPT.exists():
147
+ yield "⚠ VI checkpoint not found. Run VI training first.", None, None
148
  return
149
 
150
  # ── Generate samples ───────────────────────────────────────────────────
 
155
  if (label == "CE" and not run_ce_) or (label == "VI" and not run_vi_):
156
  continue
157
  log_lines.append(f"\n── Generating {num_samples} {label} samples ──")
158
+ yield current_log(), None, None
159
  cmd = [
160
  sys.executable, "sample.py",
161
  "--checkpoint", str(ckpt),
 
163
  "--output", str(out_path),
164
  "--seed", str(int(seed)),
165
  ]
166
+ for log in _stream_into(cmd, log_lines):
167
+ yield log, None, None
 
168
 
169
  # ── Run eval ──────────────────────────────────────────────────────────
170
  for ckpt, smpl, label in [
 
173
  ]:
174
  if (label == "CE" and not run_ce_) or (label == "VI" and not run_vi_):
175
  continue
176
+ if not smpl.exists():
177
+ log_lines.append(f"⚠ {smpl} not found β€” sample generation may have failed.")
178
+ yield current_log(), None, None
179
+ continue
180
  log_lines.append(f"\n── Evaluating {label} model ──")
181
+ yield current_log(), None, None
182
  cmd = [
183
  sys.executable, "eval.py",
184
  "--checkpoint", str(ckpt),
185
  "--test-data", str(TEST_DATA),
186
  "--num-samples", str(int(num_samples)),
187
+ "--samples-file", str(smpl),
188
  "--seed", str(int(seed)),
189
  ]
190
+ for log in _stream_into(cmd, log_lines):
191
+ yield log, None, None
 
192
 
193
  # ── Build figures ──────────────────────────────────────────────────────
194
  ce_fig = _samples_figure(CE_SMPL, "CE samples") if run_ce_ else None
195
  vi_fig = _samples_figure(VI_SMPL, "VI samples") if run_vi_ else None
196
+ yield current_log(), ce_fig, vi_fig
197
 
198
 
199
  # ---------------------------------------------------------------------------
200
  # Gradio UI
201
  # ---------------------------------------------------------------------------
202
 
203
+ with gr.Blocks(title="Ising Transformer") as demo:
204
  gr.Markdown(
205
  "# 2D Ising Transformer\n"
206
  "Autoregressive transformer trained on the 2D Ising model at the critical "
 
287
 
288
 
289
  if __name__ == "__main__":
290
+ demo.queue(default_concurrency_limit=1).launch(theme=gr.themes.Soft())