Jdice27 commited on
Commit
a67d720
·
verified ·
1 Parent(s): a7372d1

Update uncertainty.py - fix heteroscedastic loss clamping

Browse files
Files changed (1) hide show
  1. uncertainty.py +2 -2
uncertainty.py CHANGED
@@ -408,9 +408,9 @@ class HeteroscedasticHead(nn.Module):
408
  Args:
409
  hidden_states: (B, L, d_model)
410
  Returns:
411
- log_var: (B, L, n_outputs) — predicted log-variance per head
412
  """
413
- return self.log_var_head(hidden_states)
414
 
415
 
416
  # ============================================================
 
408
  Args:
409
  hidden_states: (B, L, d_model)
410
  Returns:
411
+ log_var: (B, L, n_outputs) — predicted log-variance per head, clamped to [-5, 5]
412
  """
413
+ return torch.clamp(self.log_var_head(hidden_states), -5.0, 5.0)
414
 
415
 
416
  # ============================================================