Update uncertainty.py - fix heteroscedastic loss clamping
Browse files- uncertainty.py +2 -2
uncertainty.py
CHANGED
|
@@ -408,9 +408,9 @@ class HeteroscedasticHead(nn.Module):
|
|
| 408 |
Args:
|
| 409 |
hidden_states: (B, L, d_model)
|
| 410 |
Returns:
|
| 411 |
-
log_var: (B, L, n_outputs) — predicted log-variance per head
|
| 412 |
"""
|
| 413 |
-
return self.log_var_head(hidden_states)
|
| 414 |
|
| 415 |
|
| 416 |
# ============================================================
|
|
|
|
| 408 |
Args:
|
| 409 |
hidden_states: (B, L, d_model)
|
| 410 |
Returns:
|
| 411 |
+
log_var: (B, L, n_outputs) — predicted log-variance per head, clamped to [-5, 5]
|
| 412 |
"""
|
| 413 |
+
return torch.clamp(self.log_var_head(hidden_states), -5.0, 5.0)
|
| 414 |
|
| 415 |
|
| 416 |
# ============================================================
|