| """MDLM ELBO loss with SUBS parameterisation. |
| |
| Ported from the Craftax JAX implementation (src/diffusion/loss.py). |
| Computes continuous-time loss on masked positions only, with analytic |
| SUBS weighting clipped for numerical stability. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Callable |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor |
|
|
| from src.diffusion.schedules import alpha_prime |
|
|
|
|
# Upper clamp for the SUBS importance weight; w(t) blows up as t -> 0.
_MAX_WEIGHT: float = 1000.0


def mdlm_loss(
    logits: Tensor,
    x0: Tensor,
    zt: Tensor,
    t: Tensor,
    mask_token: int,
    pad_token: int,
    schedule_fn: Callable[[Tensor], Tensor],
    weight_clip: float = _MAX_WEIGHT,
    label_smoothing: float = 0.0,
    use_importance_weighting: bool = False,
) -> Tensor:
    """Compute masked diffusion loss.

    By default uses a simple masked cross-entropy average (matching the
    reference implementation). When ``use_importance_weighting=True``,
    applies SUBS weighting ``w(t) = -alpha'(t) / (1 - alpha_t)``.

    Args:
        logits: Model output. Shape ``[B, L, vocab]``.
        x0: Clean action sequences. Shape ``[B, L]``, int64.
        zt: Noisy sequences. Shape ``[B, L]``, int64.
        t: Per-sample diffusion time in [0, 1]. Shape ``[B]``.
        mask_token: MASK token ID.
        pad_token: PAD token ID.
        schedule_fn: Noise schedule returning alpha(t). Only evaluated when
            ``use_importance_weighting`` is set.
        weight_clip: Upper clamp for SUBS weight (default 1000).
        label_smoothing: Smoothing epsilon for cross-entropy.
        use_importance_weighting: If ``True``, apply SUBS w(t) per sample.

    Returns:
        Scalar loss. Returns ``0.0`` when no masked positions exist.
    """
    B, L, V = logits.shape

    # Loss is defined only on positions the forward process masked;
    # PAD positions are never scored even if they carry the MASK id.
    is_masked = (zt == mask_token) & (x0 != pad_token)
    if not is_masked.any():
        return logits.new_tensor(0.0)

    # Clamp targets into [0, V) so cross_entropy never indexes out of
    # bounds (e.g. special-token ids >= vocab size); any such position is
    # zeroed by the mask below, so the clamped value never affects the loss.
    safe_targets = x0.clamp(0, V - 1)
    ce = F.cross_entropy(
        logits.reshape(-1, V),
        safe_targets.reshape(-1),
        reduction="none",
        label_smoothing=label_smoothing,
    ).reshape(B, L)
    ce = ce * is_masked.float()

    if use_importance_weighting:
        # SUBS weight w(t) = -alpha'(t) / (1 - alpha_t), clipped for
        # numerical stability (eps guards division near alpha_t = 1).
        alpha_t = schedule_fn(t)
        d_alpha = alpha_prime(t, schedule_fn)
        w_t = ((-d_alpha) / (1.0 - alpha_t + 1e-8)).clamp(0.0, weight_clip)

        # Per-sample masked mean (min-clamp avoids 0/0 for unmasked rows),
        # then weighted average over the batch.
        n_masked_per = is_masked.float().sum(dim=1).clamp(min=1.0)
        per_sample = ce.sum(dim=1) / n_masked_per
        return (per_sample * w_t).mean()

    # Plain masked cross-entropy averaged over all masked positions.
    n_masked_total = is_masked.float().sum().clamp(min=1.0)
    return ce.sum() / n_masked_total
|
|
|
|
def auxiliary_goal_loss(
    goal_pred: Tensor,
    global_obs: Tensor,
    pad_value: float = -1.0,
) -> Tensor:
    """MSE loss for auxiliary staircase-coordinate prediction.

    Args:
        goal_pred: Predicted normalised staircase coords. Shape ``[B, 2]``.
        global_obs: Full map glyphs. Shape ``[B, 21, 79]``, int.
        pad_value: Coordinate value used when staircase is not visible.

    Returns:
        Scalar MSE loss over samples where the staircase is visible.
        Returns ``0.0`` when no staircase is visible in the batch.
    """
    targets = find_staircase_from_glyphs(global_obs)
    targets = targets.to(goal_pred.device, dtype=goal_pred.dtype)

    # A pad_value in the row slot marks "staircase not visible"; both
    # coordinates are padded together, so checking one slot suffices.
    visible = targets[:, 0] != pad_value
    if not visible.any():
        return goal_pred.new_tensor(0.0)

    err = goal_pred[visible] - targets[visible]
    return (err * err).mean()
|
|
|
|
def find_staircase_from_glyphs(global_obs: Tensor) -> Tensor:
    """Locate the staircase '>' in the global glyph map.

    Searches for NLE staircase-down glyph (character code 62 = '>').
    Returns normalised (row/H, col/W) coordinates per batch element,
    or (-1, -1) when the staircase is not visible.

    Args:
        global_obs: Glyph map. Shape ``[B, H, W]`` or ``[H, W]``, int.

    Returns:
        Normalised coordinates. Shape ``[B, 2]`` (float32).
    """
    # Promote a single map to a batch of one.
    if global_obs.ndim == 2:
        global_obs = global_obs.unsqueeze(0)
    batch, height, width = global_obs.shape

    # Glyph ids that render as a down-staircase ('>' plus NLE stair ids).
    stair_ids = (62, 2310, 2368, 2383)
    stair_mask = torch.zeros_like(global_obs, dtype=torch.bool)
    for glyph in stair_ids:
        stair_mask |= global_obs == glyph

    out = torch.full(
        (batch, 2), -1.0, dtype=torch.float32, device=global_obs.device
    )
    # Divisors guard degenerate 1-row / 1-column maps against div-by-zero.
    row_div = max(1, height - 1)
    col_div = max(1, width - 1)
    for idx in range(batch):
        hits = stair_mask[idx].nonzero(as_tuple=False)
        if hits.numel():
            # nonzero() is row-major, so hits[0] is the top-left-most match,
            # matching the original first-occurrence behaviour.
            out[idx, 0] = hits[0, 0].float() / row_div
            out[idx, 1] = hits[0, 1].float() / col_div

    return out
|
|