ising-transformer / vi_train.py
bertran-yorro's picture
Fix REINFORCE sign: minimize F, not maximize
3b17899 verified
#!/usr/bin/env python
# /// script
# dependencies = [
# "jax[cuda12]",
# "equinox",
# "optax",
# "tqdm",
# "jaxtyping",
# ]
# ///
"""Variational inference fine-tuning of the Ising Generator.
Objective
---------
Minimise the variational free energy
    F[q] = E_{s~q}[E(s)] − T · H[q]
         = T · E_{s~q}[ E(s)/T + log q(s) ]
which satisfies F[q] = T · ( KL(q ∥ p*) − log Z ), where
p*(s) ∝ exp(−E(s)/T) is the Ising Boltzmann distribution and Z is its
partition function. Since log Z does not depend on q, minimising F minimises
KL(q ∥ p*): as F decreases, the model q approaches the correct physics.
Gradient estimator (REINFORCE / score-function)
------------------------------------------------
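    ∇_θ F = T · E_{s~q}[( E(s)/T + log q(s) − b ) · ∇_θ log q(s)]
where R(s) = E(s)/T + log q(s) is the reward and b = batch-mean reward is a
variance-reducing baseline (control variate). Subtracting a constant b leaves
the expected gradient unchanged because E_{s~q}[∇_θ log q(s)] = 0.
The implementation additionally divides the centred reward by its batch
standard deviation, which only rescales the gradient by a positive factor.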
Per training step
-----------------
1. Sample a batch of configs from the current model q_θ (slow on CPU)
2. Compute Ising energy E(s) for each sample (fast, no model)
3. Compute log q(s) via a teacher-forced forward pass (fast, one pass)
4. Assemble reward R = E/T + log q, subtract baseline, backprop (fast)
Speed note
----------
Step 1 dominates on CPU (~5 s / sample on a 32×32 lattice). On GPU it is
typically 10–100× faster and VI training with batch_size ≥ 32 is practical.
The --checkpoint flag warm-starts from a CE-pretrained model, which
dramatically reduces the number of VI steps needed to converge.
Monitoring
----------
At each step the following quantities are logged:
    e    = ⟨E/N⟩          mean energy per spin (converges to the data value ≈ −1.45)
    h    = −⟨log q⟩/N     entropy per spin in nats (random init ≈ 0.693 = ln 2)
    f    = e − T·h        Helmholtz free energy per spin (we minimise this)
    |m|  = ⟨|Σs / N|⟩     mean absolute magnetisation
    Rstd = std(R)         batch standard deviation of the REINFORCE reward
"""
import argparse
from pathlib import Path
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
from tqdm.auto import tqdm
from model import Generator, gen_config, snake_order
from sample import load_checkpoint, sample_batch
# ---------------------------------------------------------------------------
# Physical constants
# ---------------------------------------------------------------------------
# Nearest-neighbour coupling constant of the Ising Hamiltonian H = −J Σ s_i s_j.
J = 1.0
# Exact critical temperature of the 2D Ising model, T_c = 2 / ln(1 + √2) ≈ 2.2692.
# Used as the default value of the --temperature CLI option.
T_C = 2.0 / np.log(1.0 + np.sqrt(2.0))
# ---------------------------------------------------------------------------
# Ising energy (pure JAX — no model parameters, no gradient needed)
# ---------------------------------------------------------------------------
def ising_energy_per_spin(token_ids: jax.Array) -> jax.Array:
    """Return the Ising energy per spin of one snake-ordered configuration.

    The tokens ({0, 1}) are mapped to spins ({−1, +1}) and scattered onto an
    L × L grid using the row/column indices from ``snake_order``. The energy
    follows the Hamiltonian H = −J Σ_{⟨ij⟩} s_i s_j, where each site is paired
    with its right and lower neighbour; ``jnp.roll`` wraps around the edges,
    which gives periodic boundary conditions.

    Parameters
    ----------
    token_ids : array of {0, 1} tokens for a single configuration, in the
        order produced by ``snake_order``.

    Returns
    -------
    Scalar array: total bond energy divided by the number of sites L².
    """
    side = gen_config["lattice_size"]
    # snake_order returns plain index sequences, so they are concrete at trace time.
    row_idx, col_idx = snake_order(side)
    # {0, 1} → {−1, +1}
    spin_values = (token_ids * 2 - 1).astype(jnp.float32)
    lattice = (
        jnp.zeros((side, side))
        .at[jnp.asarray(row_idx), jnp.asarray(col_idx)]
        .set(spin_values)
    )
    # Shifting by −1 brings each site's right / lower neighbour onto it.
    east_neighbour = jnp.roll(lattice, -1, axis=1)
    south_neighbour = jnp.roll(lattice, -1, axis=0)
    bond_products = lattice * east_neighbour + lattice * south_neighbour
    return -J * bond_products.sum() / (side * side)
# ---------------------------------------------------------------------------
# VI loss (REINFORCE with per-batch baseline)
# ---------------------------------------------------------------------------
@eqx.filter_value_and_grad(has_aux=True)
def compute_vi_loss(
    model,
    token_ids: jax.Array,
    T: float,
) -> tuple[jax.Array, dict]:
    """REINFORCE surrogate loss whose gradient points along ∇_θ F[q].

    The returned scalar is *not* the free energy. It is the score-function
    surrogate mean(R̂ · log q), where R̂ is the whitened reward. Its gradient
    is proportional to ∇_θ F; the positive factor comes from dividing by T
    and by the batch standard deviation of the reward. Use ``aux["f"]`` to
    track F itself.

    The reward is R(s) = E(s)/T + log q(s) with the *total* energy E(s) and
    the *total* sequence log-probability, so both terms are extensive and
    carry the relative weight required by F/T = E_q[E/T + log q].

    Parameters
    ----------
    model : Generator
        Called per sample as ``model({"token_ids": ids}, False, None)``.
    token_ids : int array (batch, seq_len)
        Samples drawn from q_θ, tokens in {0, 1}.
    T : float
        Target temperature.

    Returns
    -------
    (loss, aux), grads
        The gradient is added by the ``eqx.filter_value_and_grad`` decorator
        and is taken with respect to ``model``.
        aux keys: "e", "h", "f" (all per spin), "|m|", and "reward_std"
        (batch std of the un-whitened reward, a variance diagnostic).
    """
    N = gen_config["lattice_size"] ** 2  # number of spins per configuration

    # ── log q(s) via teacher-forced forward pass ─────────────────────────────
    # The second argument (False) disables dropout so we get the exact model
    # log-probability; the third (None) is the unused dropout key.
    # in_axes=(0, None, None): vmap over the batch, broadcast the other two.
    logits = jax.vmap(model, in_axes=(0, None, None))(
        {"token_ids": token_ids}, False, None
    )  # expected shape (batch, seq_len, state_size)

    # Σ_t log p(s_t | s_{<t}), summed over the sequence axis.
    # NOTE(review): logits at position t are scored against token t+1, so the
    # first token contributes no term — confirm this matches how the sampler
    # draws the first spin.
    log_q = -optax.softmax_cross_entropy_with_integer_labels(
        logits[:, :-1, :], token_ids[:, 1:]
    ).sum(axis=-1)  # (batch,)

    # ── Ising energies ───────────────────────────────────────────────────────
    # stop_gradient keeps the energies out of the autodiff graph.
    energies = jax.lax.stop_gradient(
        jax.vmap(ising_energy_per_spin)(token_ids)  # (batch,), per spin
    )

    # ── REINFORCE reward R = E/T + log q(s) ──────────────────────────────────
    # `energies` is per spin while `log_q` is summed over all N spins, so the
    # energy must be scaled back to a total. Without the factor N the energy
    # term is N times too small and the objective is dominated by the entropy.
    reward = jax.lax.stop_gradient(N * energies / T + log_q)
    baseline = reward.mean()
    reward_std = reward.std()

    # Whiten the reward: subtract the mean AND divide by the std. This makes
    # the gradient magnitude independent of the reward scale, so the update
    # does not vanish when all rewards become nearly equal (mode collapse →
    # std → 0 → gradient → 0 with a mean-only baseline).
    reward_norm = (reward - baseline) / (reward_std + 1e-8)

    # Surrogate: ∇_θ loss = E_q[R̂ · ∇_θ log q] ∝ ∇_θ F.
    # Minimising it by gradient descent drives θ toward lower free energy.
    # NOTE: no negation — R is a cost to minimise, not a reward to maximise.
    loss = jnp.mean(jax.lax.stop_gradient(reward_norm) * log_q)

    # ── Diagnostics ──────────────────────────────────────────────────────────
    # Returned through `aux`, which has_aux=True excludes from differentiation,
    # so none of these influence training.
    e = energies.mean()  # mean energy per spin
    h = -log_q.mean() / N  # entropy estimate per spin (nats)
    f = e - T * h  # Helmholtz free energy per spin
    m = jnp.abs((token_ids * 2 - 1).astype(jnp.float32).mean(axis=-1)).mean()
    aux = {
        "e": e,
        "h": h,
        "f": f,
        "|m|": m,
        "reward_std": reward_std,  # REINFORCE variance diagnostic
    }
    return loss, aux
# ---------------------------------------------------------------------------
# Single training step (JIT-compiled; does NOT include sampling)
# ---------------------------------------------------------------------------
@eqx.filter_jit
def vi_step(
    model,
    token_ids: jax.Array,
    opt_state,
    tx,
    T: float,
):
    """Run one JIT-compiled optimisation step on an already-sampled batch.

    Evaluates ``compute_vi_loss`` (surrogate value, diagnostics and gradient),
    feeds the gradient through the optimiser ``tx`` and applies the resulting
    updates to the model. Drawing the samples is deliberately left to the
    caller, so the sampler can be profiled or swapped without re-compiling
    this gradient computation.

    Returns
    -------
    (loss, aux, model, opt_state) — the surrogate loss, the diagnostics dict,
    the updated model and the updated optimiser state.
    """
    (surrogate, diagnostics), gradients = compute_vi_loss(model, token_ids, T)
    param_updates, next_opt_state = tx.update(gradients, opt_state, model)
    updated_model = eqx.apply_updates(model, param_updates)
    return surrogate, diagnostics, updated_model, next_opt_state
# ---------------------------------------------------------------------------
# Sampling helper
# ---------------------------------------------------------------------------
# Every sample_batch call uses this fixed size; changing it triggers recompilation.
_SAMPLE_BATCH = 4


def draw_samples(model, n: int, key: jax.Array) -> jax.Array:
    """Draw ``n`` configurations from ``model`` using fixed-size sampler calls.

    ``sample_batch`` is invoked ceil(n / _SAMPLE_BATCH) times, each time with
    a fresh sub-key split off ``key``. The results are concatenated on the
    host and any surplus rows beyond ``n`` are discarded.

    Returns
    -------
    A jnp array holding the first ``n`` sampled rows. Dtype and trailing
    shape come from ``sample_batch`` (presumably int32 and (n, L²) — confirm
    against the sampler).
    """
    whole_calls, leftover = divmod(n, _SAMPLE_BATCH)
    calls_needed = whole_calls + (1 if leftover else 0)  # ceiling division

    chunks = []
    rng = key
    while len(chunks) < calls_needed:
        rng, call_key = jax.random.split(rng)
        drawn = sample_batch(model, _SAMPLE_BATCH, call_key)
        chunks.append(np.asarray(drawn))

    stacked = np.concatenate(chunks)
    return jnp.asarray(stacked[:n])
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _positive_int(text: str) -> int:
    """argparse ``type=`` callable that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {text!r}"
        ) from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def parse_args():
    """Parse the command-line options for VI fine-tuning.

    Returns
    -------
    argparse.Namespace with attributes ``checkpoint``, ``output_checkpoint``,
    ``num_steps``, ``batch_size``, ``learning_rate``, ``temperature``,
    ``log_every`` and ``seed``.

    ``--batch-size`` and ``--log-every`` must be at least 1. A batch of zero
    samples cannot be concatenated, and a logging interval of zero would
    divide by zero in the training loop, so both are rejected at parse time
    with a normal argparse usage error.
    """
    p = argparse.ArgumentParser(
        description="Variational inference fine-tuning of an Ising Generator.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--checkpoint", type=Path, default=None,
                   help="Warm-start from this .eqx file (strongly recommended).")
    p.add_argument("--output-checkpoint", type=Path, default=None,
                   help="Save final model to this path.")
    p.add_argument("--num-steps", type=int, default=200)
    p.add_argument("--batch-size", type=_positive_int, default=16,
                   help="Configurations sampled from q per gradient step.")
    p.add_argument("--learning-rate", type=float, default=1e-4)
    p.add_argument("--temperature", type=float, default=T_C,
                   help=f"Target Boltzmann temperature (default: T_c ≈ {T_C:.4f}).")
    p.add_argument("--log-every", type=_positive_int, default=1)
    p.add_argument("--seed", type=int, default=0)
    return p.parse_args()
def main():
    """Entry point: build or load the model, run VI training, optionally save.

    Each step draws ``--batch-size`` samples from the current model, performs
    one ``vi_step`` update on them and, every ``--log-every`` steps, shows the
    diagnostics in the progress bar. If ``--output-checkpoint`` is given, the
    final model is serialised there after training.
    """
    args = parse_args()
    T = args.temperature
    key = jax.random.PRNGKey(args.seed)
    # Guard the modulo below: `step % 0` would raise ZeroDivisionError.
    # abs() keeps the behaviour of a negative interval unchanged.
    log_every = max(1, abs(args.log_every))

    # ── Model ────────────────────────────────────────────────────────────────
    if args.checkpoint is not None:
        print(f"Loading checkpoint from {args.checkpoint} …")
        model = load_checkpoint(args.checkpoint)
    else:
        print("Initialising model from scratch.")
        print(" Tip: use --checkpoint to warm-start from a CE-pretrained model.")
        key, model_key = jax.random.split(key)
        model = Generator(config=gen_config, key=model_key)

    # ── Optimiser ────────────────────────────────────────────────────────────
    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=args.learning_rate),
    )
    # Only inexact (floating-point) array leaves are optimised; every other
    # leaf of the model pytree is masked out.
    tx = optax.masked(tx, jax.tree.map(eqx.is_inexact_array, model))
    opt_state = tx.init(model)

    L = gen_config["lattice_size"]
    print(f"\nVI training | steps={args.num_steps} "
          f"batch={args.batch_size} T={T:.4f} lr={args.learning_rate} L={L}")
    print(" columns: e = ⟨E/N⟩ h = −⟨log q⟩/N "
          "f = e−T·h (minimised) |m| = mean |magnetisation|\n")

    # ── Training loop ────────────────────────────────────────────────────────
    with tqdm(range(args.num_steps), unit="steps") as pbar:
        for step in pbar:
            # 1. Sample from the current model (bottleneck on CPU).
            key, sample_key = jax.random.split(key)
            token_ids = draw_samples(model, args.batch_size, sample_key)

            # 2. VI gradient step (JIT-compiled teacher-forced forward pass).
            loss, aux, model, opt_state = vi_step(
                model, token_ids, opt_state, tx, T
            )

            if step % log_every == 0:
                pbar.set_postfix(
                    e=f"{float(aux['e']):.4f}",
                    h=f"{float(aux['h']):.4f}",
                    f=f"{float(aux['f']):.4f}",
                    m=f"{float(aux['|m|']):.3f}",
                    Rstd=f"{float(aux['reward_std']):.3f}",
                )

    # ── Save ─────────────────────────────────────────────────────────────────
    if args.output_checkpoint is not None:
        args.output_checkpoint.parent.mkdir(parents=True, exist_ok=True)
        eqx.tree_serialise_leaves(args.output_checkpoint, model)
        print(f"\nSaved checkpoint → {args.output_checkpoint}")


if __name__ == "__main__":
    main()