#!/usr/bin/env python
# /// script
# dependencies = [
#     "jax[cuda12]",
#     "equinox",
#     "optax",
#     "tqdm",
#     "jaxtyping",
# ]
# ///
"""Variational inference fine-tuning of the Ising Generator.

Objective
---------
Minimise the variational free energy

    F[q] = E_{s~q}[E(s)] − T · H[q] = T · E_{s~q}[ E(s)/T + log q(s) ]

which equals KL(q ∥ p*) up to the constant log Z, where p*(s) ∝ exp(−E(s)/T)
is the Ising Boltzmann distribution. As F decreases, the model q approaches
the correct physics.

Gradient estimator (REINFORCE / score-function)
-----------------------------------------------
    ∇_θ F = T · E_{s~q}[( E(s)/T + log q(s) − b ) · ∇_θ log q(s)]

where b = batch-mean reward is a variance-reducing baseline (control
variate); subtracting it leaves the gradient unbiased.

Per training step
-----------------
1. Sample a batch of configs from the current model q_θ (slow on CPU)
2. Compute Ising energy E(s) for each sample (fast, no model)
3. Compute log q(s) via a teacher-forced forward pass (fast, one pass)
4. Assemble reward R = E/T + log q, subtract baseline, backprop (fast)

Speed note
----------
Step 1 dominates on CPU (~5 s / sample on a 32×32 lattice). On GPU it is
typically 10–100× faster and VI training with batch_size ≥ 32 is practical.
The --checkpoint flag warm-starts from a CE-pretrained model, which
dramatically reduces the number of VI steps needed to converge.
Monitoring
----------
At each step the following quantities are logged:

    e   = ⟨E/N⟩        mean energy per spin (converges to data ~−1.45)
    h   = −⟨log q⟩/N   entropy per spin in nats (random init ≈ 0.693)
    f   = e − T·h      Helmholtz free energy per spin (we minimise this)
    |m| = ⟨|Σs / N|⟩   mean absolute magnetisation
"""
import argparse
from pathlib import Path

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
from tqdm.auto import tqdm

from model import Generator, gen_config, snake_order
from sample import load_checkpoint, sample_batch

# ---------------------------------------------------------------------------
# Physical constants
# ---------------------------------------------------------------------------
J = 1.0  # ferromagnetic coupling in H = −J Σ_{⟨ij⟩} s_i s_j
T_C = 2.0 / np.log(1.0 + np.sqrt(2.0))  # exact 2D Ising T_c ≈ 2.2692


# ---------------------------------------------------------------------------
# Ising energy (pure JAX — no model parameters, no gradient needed)
# ---------------------------------------------------------------------------
def ising_energy_per_spin(token_ids: jax.Array) -> jax.Array:
    """Ising energy per spin from a snake-ordered token sequence {0, 1}.

    Hamiltonian: H = −J Σ_{⟨ij⟩} s_i s_j with periodic boundary conditions.
    Returns a scalar.

    Parameters
    ----------
    token_ids : int array, values in {0, 1}
        One configuration in snake order (assumed length L² so it fills the
        (rows, cols) scatter below — confirm against ``snake_order``).

    Returns
    -------
    jax.Array
        Scalar energy per spin, H / L².
    """
    L = gen_config["lattice_size"]
    rows, cols = snake_order(L)  # concrete at trace time
    # Tokens {0, 1} → physical spins {−1, +1}.
    spins = (token_ids * 2 - 1).astype(jnp.float32)  # {0,1} → {−1,+1}
    # Scatter the 1-D snake sequence back onto the 2-D lattice.
    grid = jnp.zeros((L, L)).at[jnp.asarray(rows), jnp.asarray(cols)].set(spins)
    # jnp.roll wraps around, implementing the periodic boundary; summing only
    # right- and down-neighbour products counts every bond exactly once.
    right = jnp.roll(grid, -1, axis=1)
    down = jnp.roll(grid, -1, axis=0)
    return -J * (grid * right + grid * down).sum() / (L * L)


# ---------------------------------------------------------------------------
# VI loss (REINFORCE with per-batch baseline)
# ---------------------------------------------------------------------------
@eqx.filter_value_and_grad(has_aux=True)
def compute_vi_loss(
    model,
    token_ids: jax.Array,
    T: float,
) -> tuple[jax.Array, dict]:
    """REINFORCE proxy loss for ∇_θ F[q].
    The returned scalar is *not* the free energy — it is the REINFORCE
    surrogate whose gradient equals ∇_θ F/T. Use aux["f"] to track F.

    Parameters
    ----------
    model     : Generator
    token_ids : int array (batch, seq_len)  samples drawn from q_θ
    T         : target temperature

    Returns
    -------
    (loss, aux), grads
        aux keys: e, h, f (per spin), |m|, reward_std (variance diagnostic)
    """
    N = gen_config["lattice_size"] ** 2  # number of spins per configuration

    # ── log q(s) via teacher-forced forward pass ─────────────────────────────
    # Disable dropout so we get the exact model log-probability.
    # in_axes=(0, None, None): vmap over batch; broadcast enable_dropout and key.
    logits = jax.vmap(model, in_axes=(0, None, None))(
        {"token_ids": token_ids}, False, None
    )  # (batch, seq_len, state_size)

    # Σ_t log p(s_t | s_{ jax.Array:
    # NOTE(review): the source is corrupted at the line above — the remainder
    # of compute_vi_loss, the jitted optimiser step (`vi_step`, called from
    # main), the `_SAMPLE_BATCH` constant, and the header of
    # `def draw_samples(model, n, key)` appear to be missing from this view.
    # The dangling docstring and body below belong to draw_samples. Recover
    # the original file before editing around this gap.
    """Sample n configurations from the model in fixed-size batches.

    Returns a jnp int32 array of shape (n, L²).
    """
    all_tokens = []
    n_calls = -(-n // _SAMPLE_BATCH)  # ceiling division: ⌈n / _SAMPLE_BATCH⌉
    for _ in range(n_calls):
        key, subkey = jax.random.split(key)
        # Move each device batch to host so the concatenate/truncate below is
        # a single cheap NumPy op before re-wrapping as one jnp array.
        all_tokens.append(np.asarray(sample_batch(model, _SAMPLE_BATCH, subkey)))
    return jnp.asarray(np.concatenate(all_tokens)[:n])


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args():
    """Parse command-line options for VI fine-tuning."""
    p = argparse.ArgumentParser(
        description="Variational inference fine-tuning of an Ising Generator.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--checkpoint", type=Path, default=None,
                   help="Warm-start from this .eqx file (strongly recommended).")
    p.add_argument("--output-checkpoint", type=Path, default=None,
                   help="Save final model to this path.")
    p.add_argument("--num-steps", type=int, default=200)
    p.add_argument("--batch-size", type=int, default=16,
                   help="Configurations sampled from q per gradient step.")
    p.add_argument("--learning-rate", type=float, default=1e-4)
    p.add_argument("--temperature", type=float, default=T_C,
                   help=f"Target
 Boltzmann temperature (default: T_c ≈ {T_C:.4f}).")
    p.add_argument("--log-every", type=int, default=1)
    p.add_argument("--seed", type=int, default=0)
    return p.parse_args()


def main():
    """Run the VI fine-tuning loop: sample → gradient step → log → save."""
    args = parse_args()
    T = args.temperature  # target Boltzmann temperature for F[q]
    key = jax.random.PRNGKey(args.seed)

    # ── Model ─────────────────────────────────────────────────────────────────
    if args.checkpoint is not None:
        print(f"Loading checkpoint from {args.checkpoint} …")
        model = load_checkpoint(args.checkpoint)
    else:
        print("Initialising model from scratch.")
        print(" Tip: use --checkpoint to warm-start from a CE-pretrained model.")
        key, model_key = jax.random.split(key)
        model = Generator(config=gen_config, key=model_key)

    # ── Optimiser ─────────────────────────────────────────────────────────────
    # Gradient clipping tames the high-variance REINFORCE estimator.
    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=args.learning_rate),
    )
    # Update only inexact-array (float) leaves of the model pytree; static
    # fields and integer leaves are masked out of the optimiser state.
    tx = optax.masked(tx, jax.tree.map(eqx.is_inexact_array, model))
    opt_state = tx.init(model)

    L = gen_config["lattice_size"]
    print(f"\nVI training | steps={args.num_steps} "
          f"batch={args.batch_size} T={T:.4f} lr={args.learning_rate} L={L}")
    print(" columns: e = ⟨E/N⟩ h = −⟨log q⟩/N "
          "f = e−T·h (minimised) |m| = mean |magnetisation|\n")

    # ── Training loop ─────────────────────────────────────────────────────────
    with tqdm(range(args.num_steps), unit="steps") as pbar:
        for step in pbar:
            # 1. Sample from current model (bottleneck on CPU)
            key, sample_key = jax.random.split(key)
            token_ids = draw_samples(model, args.batch_size, sample_key)

            # 2. VI gradient step (JIT-compiled teacher-forced forward pass)
            # NOTE(review): vi_step is defined earlier in the file (not visible
            # in this view); presumably it wraps compute_vi_loss + tx.update.
            loss, aux, model, opt_state = vi_step(
                model, token_ids, opt_state, tx, T
            )

            if step % args.log_every == 0:
                # aux carries per-spin diagnostics produced by compute_vi_loss.
                pbar.set_postfix(
                    e = f"{float(aux['e']):.4f}",
                    h = f"{float(aux['h']):.4f}",
                    f = f"{float(aux['f']):.4f}",
                    m = f"{float(aux['|m|']):.3f}",
                    Rstd= f"{float(aux['reward_std']):.3f}",
                )

    # ── Save ──────────────────────────────────────────────────────────────────
    if args.output_checkpoint is not None:
        args.output_checkpoint.parent.mkdir(parents=True, exist_ok=True)
        eqx.tree_serialise_leaves(args.output_checkpoint, model)
        print(f"\nSaved checkpoint → {args.output_checkpoint}")


if __name__ == "__main__":
    main()