#!/usr/bin/env python
# /// script
# dependencies = [
#   "jax[cuda12]",
#   "equinox",
#   "optax",
#   "tqdm",
#   "jaxtyping",
# ]
# ///
"""Variational inference fine-tuning of the Ising Generator.

Objective
---------
Minimise the variational free energy

    F[q] = E_{s~q}[E(s)] − T · H[q]
         = T · E_{s~q}[ E(s)/T + log q(s) ]

so that F/T = KL(q ∥ p*) − log Z, where
p*(s) ∝ exp(−E(s)/T) is the Ising Boltzmann distribution.
As F decreases, the model q approaches the correct physics.

Gradient estimator (REINFORCE / score-function)
------------------------------------------------
    ∇_θ F = T · E_{s~q}[( E(s)/T + log q(s) − b ) · ∇_θ log q(s)]

where the baseline b (here the batch mean of the reward) leaves the gradient
unbiased, since E_q[∇_θ log q(s)] = 0, and substantially reduces its variance.
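
A toy check of the identity (illustrative; not used by this script): for
q_θ = Bernoulli(σ(θ)) and a cost c(s), the score-function estimate of
∇_θ E_q[c] should approach the exact value σ'(θ)·(c(1) − c(0)):

    import jax, jax.numpy as jnp
    theta = 0.3
    p     = jax.nn.sigmoid(theta)                  # q(s=1)
    s     = jax.random.bernoulli(jax.random.PRNGKey(0), p, (100_000,))
    c     = jnp.where(s, 2.0, -1.0)                # c(1)=2, c(0)=-1
    score = jnp.where(s, 1 - p, -p)                # d/dtheta log q(s)
    print((c * score).mean())                      # ~ p*(1-p)*3 ~ 0.733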

Per training step
-----------------
  1. Sample a batch of configs from the current model q_θ          (slow on CPU)
  2. Compute Ising energy E(s) for each sample                    (fast, no model)
  3. Compute log q(s) via a teacher-forced forward pass            (fast, one pass)
  4. Assemble reward R = E/T + log q, subtract baseline, backprop  (fast)

Speed note
----------
Step 1 dominates on CPU (~5 s / sample on a 32×32 lattice).  On GPU it is
typically 10–100× faster and VI training with batch_size ≥ 32 is practical.
The --checkpoint flag warm-starts from a CE-pretrained model, which
dramatically reduces the number of VI steps needed to converge.

Monitoring
----------
At each step the following quantities are logged:

  e   = ⟨E/N⟩              mean energy per spin  (converges to data ~−1.45)
  h   = −⟨log q⟩/N         entropy per spin in nats  (random init ≈ 0.693)
  f   = e − T·h            Helmholtz free energy per spin (we minimise this)
  |m| = ⟨|Σs / N|⟩         mean absolute magnetisation
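
Usage (illustrative)
--------------------
    python vi_finetune.py --checkpoint runs/ce_pretrained.eqx \
        --output-checkpoint runs/vi_finetuned.eqx \
        --num-steps 200 --batch-size 16

(The script and checkpoint names above are placeholders; the flags are
the ones defined in parse_args below.)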
"""

import argparse
from pathlib import Path

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
from tqdm.auto import tqdm

from model import Generator, gen_config, snake_order
from sample import load_checkpoint, sample_batch


# ---------------------------------------------------------------------------
# Physical constants
# ---------------------------------------------------------------------------

J   = 1.0
T_C = 2.0 / np.log(1.0 + np.sqrt(2.0))   # exact 2D Ising T_c ≈ 2.2692


# ---------------------------------------------------------------------------
# Ising energy (pure JAX; no model parameters, no gradient needed)
# ---------------------------------------------------------------------------

def ising_energy_per_spin(token_ids: jax.Array) -> jax.Array:
    """Ising energy per spin from a snake-ordered token sequence {0, 1}.

    Hamiltonian: H = −J Σ_{⟨ij⟩} s_i s_j  with periodic boundary conditions.
    Returns a scalar.
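
    Sanity check: with all spins up, every site contributes two satisfied
    bonds under PBC, so the energy per spin is exactly −2J:

        >>> L = gen_config["lattice_size"]
        >>> float(ising_energy_per_spin(jnp.ones(L * L, dtype=jnp.int32)))
        -2.0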
    """
    L    = gen_config["lattice_size"]
    rows, cols = snake_order(L)                              # concrete at trace time
    spins = (token_ids * 2 - 1).astype(jnp.float32)        # {0,1} → {-1,+1}
    grid  = jnp.zeros((L, L)).at[jnp.asarray(rows), jnp.asarray(cols)].set(spins)
    right = jnp.roll(grid, -1, axis=1)
    down  = jnp.roll(grid, -1, axis=0)
    return -J * (grid * right + grid * down).sum() / (L * L)


# ---------------------------------------------------------------------------
# VI loss  (REINFORCE with per-batch baseline)
# ---------------------------------------------------------------------------

@eqx.filter_value_and_grad(has_aux=True)
def compute_vi_loss(
    model,
    token_ids: jax.Array,
    T: float,
) -> tuple[jax.Array, dict]:
    """REINFORCE proxy loss for βˆ‡_ΞΈ F[q].

    The returned scalar is *not* the free energy β€” it is the REINFORCE
    surrogate whose gradient equals βˆ‡_ΞΈ F/T.  Use aux["f"] to track F.

    Parameters
    ----------
    model     : Generator
    token_ids : int array  (batch, seq_len)  samples drawn from q_θ
    T         : target temperature

    Returns
    -------
    (loss, aux), grads
      aux keys:  e, h, f  (per spin),  |m|,  reward_std  (variance diagnostic)
    """
    N = gen_config["lattice_size"] ** 2

    # ── log q(s) via teacher-forced forward pass ─────────────────────────────
    # Disable dropout so we get the exact model log-probability.
    # in_axes=(0, None, None): vmap over batch; broadcast enable_dropout and key.
    logits = jax.vmap(model, in_axes=(0, None, None))(
        {"token_ids": token_ids}, False, None
    )   # (batch, seq_len, state_size)

    # Σ_t log p(s_t | s_{<t}), summed over the sequence axis
    log_q = -optax.softmax_cross_entropy_with_integer_labels(
        logits[:, :-1, :], token_ids[:, 1:]
    ).sum(axis=-1)   # (batch,)

    # ── Ising energies ────────────────────────────────────────────────────────
    # jax.lax.stop_gradient keeps energies out of the autodiff graph.
    energies = jax.lax.stop_gradient(
        jax.vmap(ising_energy_per_spin)(token_ids)   # (batch,)
    )

    # ── REINFORCE reward  R = E/T + log q(s), both per spin ─────────────────
    # NOTE: `energies` is already per spin, so log q must be divided by N as
    # well; otherwise the entropy term outweighs the energy term by ~N.
    reward   = jax.lax.stop_gradient(energies / T + log_q / N)
    baseline = reward.mean()
    # Whiten the reward: subtract mean AND divide by std.
    # This makes the gradient magnitude independent of reward scale and
    # prevents the high-variance collapse where all rewards are near-equal
    # (mode collapse → std→0 → gradient→0 with mean-only baseline).
    reward_norm = (reward - baseline) / (reward.std() + 1e-8)
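    # Illustration: rewards [1.00, 1.02] give mean-only advantages of ±0.01,
    # which shrink as the batch homogenises; whitening rescales them to ±1,
    # preserving a usable learning signal.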

    # Proxy loss: ∇_θ loss = E_q[R̂ · ∇_θ log q] ∝ ∇_θ (F/T)  (up to the
    # whitening scale).  Minimising it drives θ toward lower free energy.
    # NOTE: no negation; R is a cost to minimise, not a reward to maximise.
    loss = jnp.mean(jax.lax.stop_gradient(reward_norm) * log_q)

    # ── Diagnostics (returned as aux; they do not enter the loss) ───────────
    e   = energies.mean()                   # mean energy per spin
    h   = -log_q.mean() / N                 # entropy per spin  (nats)
    f   = e - T * h                         # Helmholtz free energy per spin
    m   = jnp.abs((token_ids * 2 - 1).astype(jnp.float32).mean(axis=-1)).mean()
    aux = {
        "e":           e,
        "h":           h,
        "f":           f,
        "|m|":         m,
        "reward_std":  reward.std(),        # REINFORCE variance diagnostic
    }
    return loss, aux


# ---------------------------------------------------------------------------
# Single training step  (JIT-compiled; does NOT include sampling)
# ---------------------------------------------------------------------------

@eqx.filter_jit
def vi_step(
    model,
    token_ids: jax.Array,
    opt_state,
    tx,
    T: float,
):
    """Compute VI loss + gradient and apply one optimiser update.

    Sampling is intentionally excluded so you can profile / replace it
    without re-compiling the gradient computation.
    """
    (loss, aux), grads = compute_vi_loss(model, token_ids, T)
    updates, opt_state = tx.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return loss, aux, model, opt_state
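
# Example (sketch): profile sampling and the gradient step independently,
# with tx / opt_state built as in main() below:
#
#     tokens = draw_samples(model, 16, jax.random.PRNGKey(0))            # slow
#     loss, aux, model, opt_state = vi_step(model, tokens, opt_state, tx, T_C)
#
# The first vi_step call triggers JIT compilation; subsequent calls with the
# same batch shape reuse the cached executable.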


# ---------------------------------------------------------------------------
# Sampling helper
# ---------------------------------------------------------------------------

_SAMPLE_BATCH = 4   # fixed call-site batch; changing triggers recompilation


def draw_samples(model, n: int, key: jax.Array) -> jax.Array:
    """Sample n configurations from the model in fixed-size batches.

    Returns a jnp int32 array of shape (n, L²).
    """
    all_tokens = []
    n_calls = -(-n // _SAMPLE_BATCH)   # ceiling division
    for _ in range(n_calls):
        key, subkey = jax.random.split(key)
        all_tokens.append(np.asarray(sample_batch(model, _SAMPLE_BATCH, subkey)))
    return jnp.asarray(np.concatenate(all_tokens)[:n])


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args():
    p = argparse.ArgumentParser(
        description="Variational inference fine-tuning of an Ising Generator.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--checkpoint",        type=Path, default=None,
                   help="Warm-start from this .eqx file (strongly recommended).")
    p.add_argument("--output-checkpoint", type=Path, default=None,
                   help="Save final model to this path.")
    p.add_argument("--num-steps",         type=int,   default=200)
    p.add_argument("--batch-size",        type=int,   default=16,
                   help="Configurations sampled from q per gradient step.")
    p.add_argument("--learning-rate",     type=float, default=1e-4)
    p.add_argument("--temperature",       type=float, default=T_C,
                   help=f"Target Boltzmann temperature (default: T_c β‰ˆ {T_C:.4f}).")
    p.add_argument("--log-every",         type=int,   default=1)
    p.add_argument("--seed",              type=int,   default=0)
    return p.parse_args()


def main():
    args = parse_args()
    T   = args.temperature
    key = jax.random.PRNGKey(args.seed)

    # ── Model ─────────────────────────────────────────────────────────────────
    if args.checkpoint is not None:
        print(f"Loading checkpoint from {args.checkpoint} …")
        model = load_checkpoint(args.checkpoint)
    else:
        print("Initialising model from scratch.")
        print("  Tip: use --checkpoint to warm-start from a CE-pretrained model.")
        key, model_key = jax.random.split(key)
        model = Generator(config=gen_config, key=model_key)

    # ── Optimiser ─────────────────────────────────────────────────────────────
    tx = optax.chain(
        optax.clip_by_global_norm(1.0),   # guard against REINFORCE gradient spikes
        optax.adam(learning_rate=args.learning_rate),
    )
    # Mask out non-array leaves (activation functions, config entries, ...)
    # so optax only updates the trainable parameters of the equinox module.
    tx        = optax.masked(tx, jax.tree.map(eqx.is_inexact_array, model))
    opt_state = tx.init(model)

    L = gen_config["lattice_size"]
    print(f"\nVI training  |  steps={args.num_steps}  "
          f"batch={args.batch_size}  T={T:.4f}  lr={args.learning_rate}  L={L}")
    print("  columns:  e = ⟨E/N⟩   h = βˆ’βŸ¨log q⟩/N   "
          "f = eβˆ’TΒ·h  (minimised)   |m| = mean |magnetisation|\n")

    # ── Training loop ─────────────────────────────────────────────────────────
    with tqdm(range(args.num_steps), unit="steps") as pbar:
        for step in pbar:
            # 1. Sample from current model (bottleneck on CPU)
            key, sample_key = jax.random.split(key)
            token_ids = draw_samples(model, args.batch_size, sample_key)

            # 2. VI gradient step (JIT-compiled teacher-forced forward pass)
            loss, aux, model, opt_state = vi_step(
                model, token_ids, opt_state, tx, T
            )

            if step % args.log_every == 0:
                pbar.set_postfix(
                    e   = f"{float(aux['e']):.4f}",
                    h   = f"{float(aux['h']):.4f}",
                    f   = f"{float(aux['f']):.4f}",
                    m   = f"{float(aux['|m|']):.3f}",
                    Rstd= f"{float(aux['reward_std']):.3f}",
                )

    # ── Save ──────────────────────────────────────────────────────────────────
    if args.output_checkpoint is not None:
        args.output_checkpoint.parent.mkdir(parents=True, exist_ok=True)
        eqx.tree_serialise_leaves(args.output_checkpoint, model)
        print(f"\nSaved checkpoint β†’ {args.output_checkpoint}")


if __name__ == "__main__":
    main()