bertran-yorro committed on
Commit
3b17899
·
verified ·
1 Parent(s): 2c223e2

Fix REINFORCE sign: minimize F, not maximize

Browse files
Files changed (1) hide show
  1. vi_train.py +4 -2
vi_train.py CHANGED
@@ -147,8 +147,10 @@ def compute_vi_loss(
147
  # (mode collapse → std→0 → gradient→0 with mean-only baseline).
148
  reward_norm = (reward - baseline) / (reward.std() + 1e-8)
149
 
150
- # Proxy loss: ∇ loss = −E_q[ R̂ · ∇ log q ] = ∇ F/T (up to scale)
151
- loss = jnp.mean(jax.lax.stop_gradient(reward_norm) * (-log_q))
 
 
152
 
153
  # ── Diagnostics (all stop-gradiented; no effect on training) ─────────────
154
  e = energies.mean() # mean energy per spin
 
147
  # (mode collapse → std→0 → gradient→0 with mean-only baseline).
148
  reward_norm = (reward - baseline) / (reward.std() + 1e-8)
149
 
150
+ # Proxy loss: ∇_θ loss = E_q[R̂ · ∇_θ log q] = ∇_θ (F/T) (up to scale)
151
+ # Minimising this via gradient descent drives θ toward lower free energy.
152
+ # NOTE: no negation — R is a cost to minimise, not a reward to maximise.
153
+ loss = jnp.mean(jax.lax.stop_gradient(reward_norm) * log_q)
154
 
155
  # ── Diagnostics (all stop-gradiented; no effect on training) ─────────────
156
  e = energies.mean() # mean energy per spin