Update model.py - fix heteroscedastic loss clamping
Browse files
model.py
CHANGED
|
@@ -727,11 +727,10 @@ class NextStateLoss(nn.Module):
|
|
| 727 |
# --- Heteroscedastic regularization (learned aleatoric uncertainty) ---
|
| 728 |
if 'log_var' in predictions:
|
| 729 |
log_var = predictions['log_var'][:, :-1, :] # (B, L-1, 6)
|
| 730 |
-
#
|
| 731 |
-
|
| 732 |
-
#
|
| 733 |
-
|
| 734 |
-
losses['log_var_reg'] = log_var_penalty
|
| 735 |
|
| 736 |
# Total loss
|
| 737 |
total_loss = sum(losses.values())
|
|
|
|
| 727 |
# --- Heteroscedastic regularization (learned aleatoric uncertainty) ---
|
| 728 |
if 'log_var' in predictions:
|
| 729 |
log_var = predictions['log_var'][:, :-1, :] # (B, L-1, 6)
|
| 730 |
+
# Clamp log_var to prevent collapse: [-5, 5] range
|
| 731 |
+
log_var_clamped = torch.clamp(log_var, -5.0, 5.0)
|
| 732 |
+
# Regularize toward 0 (unit variance prior)
|
| 733 |
+
losses['log_var_reg'] = 0.1 * (log_var_clamped ** 2).mean()
|
|
|
|
| 734 |
|
| 735 |
# Total loss
|
| 736 |
total_loss = sum(losses.values())
|