#!/usr/bin/env python
# /// script
# dependencies = [
# "jax[cuda12]",
# "equinox",
# "optax",
# "tqdm",
# "jaxtyping",
# ]
# ///
"""Variational inference fine-tuning of the Ising Generator.
Objective
---------
Minimise the variational free energy
    F[q] = E_{s~q}[E(s)] − T · H[q]
         = T · E_{s~q}[ E(s)/T + log q(s) ]
so F/T = KL(q ‖ p*) − log Z, i.e. F equals the KL up to a constant, where
p*(s) ∝ exp(−E(s)/T) is the Ising Boltzmann distribution.
As F decreases, the model q approaches the correct physics.
Gradient estimator (REINFORCE / score-function)
------------------------------------------------
    ∇_θ F = T · E_{s~q}[ ( E(s)/T + log q(s) − b ) · ∇_θ log q(s) ]
where the baseline b (the batch-mean reward) is a control variate: it
reduces gradient variance without adding bias, since E_q[∇_θ log q] = 0.
Per training step
-----------------
1. Sample a batch of configs from the current model q_θ (slow on CPU)
2. Compute Ising energy E(s) for each sample (fast, no model)
3. Compute log q(s) via a teacher-forced forward pass (fast, one pass)
4. Assemble reward R = E/T + log q, subtract baseline, backprop (fast)
Speed note
----------
Step 1 dominates on CPU (~5 s / sample on a 32×32 lattice). On GPU it is
typically 10–100× faster and VI training with batch_size ≥ 32 is practical.
The --checkpoint flag warm-starts from a CE-pretrained model, which
dramatically reduces the number of VI steps needed to converge.
Monitoring
----------
At each step the following quantities are logged:
    e   = ⟨E/N⟩         mean energy per spin (converges to data ~ −1.45)
    h   = −⟨log q⟩/N    entropy per spin in nats (random init ≈ 0.693)
    f   = e − T·h       Helmholtz free energy per spin (we minimise this)
    |m| = ⟨|Σ s / N|⟩   mean absolute magnetisation
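
Example invocation (script and checkpoint names are illustrative):
    python vi_finetune.py --checkpoint runs/ce_pretrained.eqx \
        --num-steps 500 --batch-size 32 \
        --output-checkpoint runs/vi.eqx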
"""
import argparse
from pathlib import Path
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
from tqdm.auto import tqdm
from model import Generator, gen_config, snake_order
from sample import load_checkpoint, sample_batch
# ---------------------------------------------------------------------------
# Physical constants
# ---------------------------------------------------------------------------
J = 1.0
T_C = 2.0 / np.log(1.0 + np.sqrt(2.0)) # exact 2D Ising T_c ≈ 2.2692
# ---------------------------------------------------------------------------
# Ising energy (pure JAX: no model parameters, no gradient needed)
# ---------------------------------------------------------------------------
def ising_energy_per_spin(token_ids: jax.Array) -> jax.Array:
"""Ising energy per spin from a snake-ordered token sequence {0, 1}.
Hamiltonian: H = −J Σ_{⟨ij⟩} s_i s_j with periodic boundary conditions.
Returns a scalar.
"""
L = gen_config["lattice_size"]
rows, cols = snake_order(L) # concrete at trace time
spins = (token_ids * 2 - 1).astype(jnp.float32) # {0,1} → {−1,+1}
grid = jnp.zeros((L, L)).at[jnp.asarray(rows), jnp.asarray(cols)].set(spins)
right = jnp.roll(grid, -1, axis=1)
down = jnp.roll(grid, -1, axis=0)
return -J * (grid * right + grid * down).sum() / (L * L)
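# Quick sanity check (illustrative; not executed by the script): the all-up
# configuration is the ferromagnetic ground state, so on a periodic lattice
# its energy per spin is exactly -2J = -2.0 (every site has two bonds of -J):
#   ids = jnp.ones(gen_config["lattice_size"] ** 2, dtype=jnp.int32)
#   assert float(ising_energy_per_spin(ids)) == -2.0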
# ---------------------------------------------------------------------------
# VI loss (REINFORCE with per-batch baseline)
# ---------------------------------------------------------------------------
@eqx.filter_value_and_grad(has_aux=True)
def compute_vi_loss(
model,
token_ids: jax.Array,
T: float,
) -> tuple[jax.Array, dict]:
"""REINFORCE proxy loss for β_ΞΈ F[q].
The returned scalar is *not* the free energy β it is the REINFORCE
surrogate whose gradient equals β_ΞΈ F/T. Use aux["f"] to track F.
Parameters
----------
model : Generator
token_ids : int array (batch, seq_len) samples drawn from q_θ
T : target temperature
Returns
-------
(loss, aux), grads
aux keys: e, h, f (per spin), |m|, reward_std (variance diagnostic)
"""
N = gen_config["lattice_size"] ** 2
# ── log q(s) via teacher-forced forward pass ─────────────────────────────
# Disable dropout so we get the exact model log-probability.
# in_axes=(0, None, None): vmap over batch; broadcast enable_dropout and key.
logits = jax.vmap(model, in_axes=(0, None, None))(
{"token_ids": token_ids}, False, None
) # (batch, seq_len, state_size)
# Σ_t log p(s_t | s_{<t}), summed over the sequence axis
log_q = -optax.softmax_cross_entropy_with_integer_labels(
logits[:, :-1, :], token_ids[:, 1:]
).sum(axis=-1) # (batch,)
# ── Ising energies ───────────────────────────────────────────────────────
# jax.lax.stop_gradient keeps energies out of the autodiff graph.
energies = jax.lax.stop_gradient(
jax.vmap(ising_energy_per_spin)(token_ids) # (batch,)
)
# ── REINFORCE reward R = E/T + log q(s) ──────────────────────────────────
reward = jax.lax.stop_gradient(energies / T + log_q)
baseline = reward.mean()
# Whiten the reward: subtract the mean AND divide by the std.
# This makes the gradient magnitude independent of the reward scale and
# keeps a usable learning signal when all rewards become near-equal
# (mode collapse → std→0 → gradient→0 with a mean-only baseline).
reward_norm = (reward - baseline) / (reward.std() + 1e-8)
# Proxy loss: ∇_θ loss = E_q[R̂ · ∇_θ log q] = ∇_θ (F/T) (up to scale)
# Minimising this via gradient descent drives θ toward lower free energy.
# NOTE: no negation; R is a cost to minimise, not a reward to maximise.
loss = jnp.mean(jax.lax.stop_gradient(reward_norm) * log_q)
# ── Diagnostics (reported via aux; no effect on the gradient) ────────────
e = energies.mean() # mean energy per spin
h = -log_q.mean() / N # entropy per spin (nats)
f = e - T * h # Helmholtz free energy per spin
m = jnp.abs((token_ids * 2 - 1).astype(jnp.float32).mean(axis=-1)).mean()
aux = {
"e": e,
"h": h,
"f": f,
"|m|": m,
"reward_std": reward.std(), # REINFORCE variance diagnostic
}
return loss, aux
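# Illustrative usage (a sketch, not called at import time): the filtered
# decorator returns the value and gradient in a single pass, e.g.
#   (loss, aux), grads = compute_vi_loss(model, token_ids, T_C)
# with token_ids of shape (batch, L**2) and grads a pytree shaped like
# `model`, populated only at its inexact-array (trainable) leaves.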
# ---------------------------------------------------------------------------
# Single training step (JIT-compiled; does NOT include sampling)
# ---------------------------------------------------------------------------
@eqx.filter_jit
def vi_step(
model,
token_ids: jax.Array,
opt_state,
tx,
T: float,
):
"""Compute VI loss + gradient and apply one optimiser update.
Sampling is intentionally excluded so you can profile / replace it
without re-compiling the gradient computation.
"""
(loss, aux), grads = compute_vi_loss(model, token_ids, T)
updates, opt_state = tx.update(grads, opt_state, model)
model = eqx.apply_updates(model, updates)
return loss, aux, model, opt_state
# ---------------------------------------------------------------------------
# Sampling helper
# ---------------------------------------------------------------------------
_SAMPLE_BATCH = 4 # fixed call-site batch; changing triggers recompilation
def draw_samples(model, n: int, key: jax.Array) -> jax.Array:
"""Sample n configurations from the model in fixed-size batches.
Returns a jnp int32 array of shape (n, LΒ²).
"""
all_tokens = []
n_calls = -(-n // _SAMPLE_BATCH) # ceiling division
for _ in range(n_calls):
key, subkey = jax.random.split(key)
all_tokens.append(np.asarray(sample_batch(model, _SAMPLE_BATCH, subkey)))
return jnp.asarray(np.concatenate(all_tokens)[:n])
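# Example: n=10 with _SAMPLE_BATCH=4 issues 3 sampling calls (12 configs)
# and keeps the first 10; a little sampling is wasted so that sample_batch
# always sees the same batch shape and stays compiled.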
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args():
p = argparse.ArgumentParser(
description="Variational inference fine-tuning of an Ising Generator.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
p.add_argument("--checkpoint", type=Path, default=None,
help="Warm-start from this .eqx file (strongly recommended).")
p.add_argument("--output-checkpoint", type=Path, default=None,
help="Save final model to this path.")
p.add_argument("--num-steps", type=int, default=200)
p.add_argument("--batch-size", type=int, default=16,
help="Configurations sampled from q per gradient step.")
p.add_argument("--learning-rate", type=float, default=1e-4)
p.add_argument("--temperature", type=float, default=T_C,
help=f"Target Boltzmann temperature (default: T_c β {T_C:.4f}).")
p.add_argument("--log-every", type=int, default=1)
p.add_argument("--seed", type=int, default=0)
return p.parse_args()
def main():
args = parse_args()
T = args.temperature
key = jax.random.PRNGKey(args.seed)
# ── Model ────────────────────────────────────────────────────────────────
if args.checkpoint is not None:
print(f"Loading checkpoint from {args.checkpoint} β¦")
model = load_checkpoint(args.checkpoint)
else:
print("Initialising model from scratch.")
print(" Tip: use --checkpoint to warm-start from a CE-pretrained model.")
key, model_key = jax.random.split(key)
model = Generator(config=gen_config, key=model_key)
# ── Optimiser ────────────────────────────────────────────────────────────
tx = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adam(learning_rate=args.learning_rate),
)
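# Restrict updates to trainable leaves: eqx.is_inexact_array marks the
# floating-point parameter arrays, so optax.masked skips the module's
# static (non-array) leaves, a standard Equinox + optax pairing.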
tx = optax.masked(tx, jax.tree.map(eqx.is_inexact_array, model))
opt_state = tx.init(model)
L = gen_config["lattice_size"]
print(f"\nVI training | steps={args.num_steps} "
f"batch={args.batch_size} T={T:.4f} lr={args.learning_rate} L={L}")
print(" columns: e = β¨E/Nβ© h = ββ¨log qβ©/N "
"f = eβTΒ·h (minimised) |m| = mean |magnetisation|\n")
# ── Training loop ────────────────────────────────────────────────────────
with tqdm(range(args.num_steps), unit="steps") as pbar:
for step in pbar:
# 1. Sample from current model (bottleneck on CPU)
key, sample_key = jax.random.split(key)
token_ids = draw_samples(model, args.batch_size, sample_key)
# 2. VI gradient step (JIT-compiled teacher-forced forward pass)
loss, aux, model, opt_state = vi_step(
model, token_ids, opt_state, tx, T
)
if step % args.log_every == 0:
pbar.set_postfix(
e = f"{float(aux['e']):.4f}",
h = f"{float(aux['h']):.4f}",
f = f"{float(aux['f']):.4f}",
m = f"{float(aux['|m|']):.3f}",
Rstd= f"{float(aux['reward_std']):.3f}",
)
# ── Save ─────────────────────────────────────────────────────────────────
if args.output_checkpoint is not None:
args.output_checkpoint.parent.mkdir(parents=True, exist_ok=True)
eqx.tree_serialise_leaves(args.output_checkpoint, model)
print(f"\nSaved checkpoint β {args.output_checkpoint}")
if __name__ == "__main__":
main()