Jdice27 committed on
Commit
a7372d1
·
verified ·
1 Parent(s): e43dca4

Update model.py - fix heteroscedastic loss clamping

Browse files
Files changed (1) hide show
  1. model.py +4 -5
model.py CHANGED
@@ -727,11 +727,10 @@ class NextStateLoss(nn.Module):
727
  # --- Heteroscedastic regularization (learned aleatoric uncertainty) ---
728
  if 'log_var' in predictions:
729
  log_var = predictions['log_var'][:, :-1, :] # (B, L-1, 6)
730
- # Regularize: penalize overly high uncertainty (prevent collapse)
731
- # The individual heads already implicitly learn to attend to uncertainty
732
- # via the gradient signal, but we add a mild KL-like penalty
733
- log_var_penalty = 0.01 * log_var.mean()
734
- losses['log_var_reg'] = log_var_penalty
735
 
736
  # Total loss
737
  total_loss = sum(losses.values())
 
727
  # --- Heteroscedastic regularization (learned aleatoric uncertainty) ---
728
  if 'log_var' in predictions:
729
  log_var = predictions['log_var'][:, :-1, :] # (B, L-1, 6)
730
+ # Clamp log_var to prevent collapse: [-5, 5] range
731
+ log_var_clamped = torch.clamp(log_var, -5.0, 5.0)
732
+ # Regularize toward 0 (unit variance prior)
733
+ losses['log_var_reg'] = 0.1 * (log_var_clamped ** 2).mean()
 
734
 
735
  # Total loss
736
  total_loss = sum(losses.values())