Spaces:
Paused
Paused
Fix REINFORCE sign: minimize F, not maximize
Browse files- vi_train.py +4 -2
vi_train.py
CHANGED
|
@@ -147,8 +147,10 @@ def compute_vi_loss(
|
|
| 147 |
# (mode collapse → std→0 → gradient→0 with mean-only baseline).
|
| 148 |
reward_norm = (reward - baseline) / (reward.std() + 1e-8)
|
| 149 |
|
| 150 |
-
# Proxy loss: ∇ loss
|
| 151 |
-
|
|
|
|
|
|
|
| 152 |
|
| 153 |
# ── Diagnostics (all stop-gradiented; no effect on training) ─────────────
|
| 154 |
e = energies.mean() # mean energy per spin
|
|
|
|
| 147 |
# (mode collapse → std→0 → gradient→0 with mean-only baseline).
|
| 148 |
reward_norm = (reward - baseline) / (reward.std() + 1e-8)
|
| 149 |
|
| 150 |
+
# Proxy loss: ∇_θ loss = E_q[R̂ · ∇_θ log q] = ∇_θ (F/T) (up to scale)
|
| 151 |
+
# Minimising this via gradient descent drives θ toward lower free energy.
|
| 152 |
+
# NOTE: no negation — R is a cost to minimise, not a reward to maximise.
|
| 153 |
+
loss = jnp.mean(jax.lax.stop_gradient(reward_norm) * log_q)
|
| 154 |
|
| 155 |
# ── Diagnostics (all stop-gradiented; no effect on training) ─────────────
|
| 156 |
e = energies.mean() # mean energy per spin
|