"""ReMDM reverse denoising with remasking strategies.

Ported from the Craftax JAX implementation (src/diffusion/sampling.py).
Implements MaskGIT-style progressive unmasking with optional stochastic
remasking (ReMDM) using three strategy variants (``rescale``, ``cap`` and
``conf``).
"""
from __future__ import annotations
from types import SimpleNamespace
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical
from src.diffusion.schedules import get_schedule
# Hazard lookup tables. The original author describes these as NLE glyph IDs
# and character codes for walls, locked doors, lava and water.
# NOTE(review): exact glyph-ID meanings are not verifiable here; confirm
# against the NLE glyph tables.
_HAZARD_GLYPHS: frozenset[int] = frozenset((2359, 2360, 2389, 2390))
_HAZARD_CHARS: frozenset[int] = frozenset(map(ord, "|-+LW"))

# Cardinal action index -> (dy, dx) step on the local crop (0=N, 1=E, 2=S, 3=W).
_CARDINAL_OFFSETS: dict[int, tuple[int, int]] = dict(
    enumerate([(-1, 0), (0, 1), (1, 0), (0, -1)])
)

# Physics checks only look at this many leading plan positions.
_N_PHYSICS_CHECK = 8


def _check_hazard(local_crop: np.ndarray, action: int) -> bool:
    """Report whether a cardinal move from the crop centre hits a hazard.

    The agent is assumed to sit at ``(size // 2, size // 2)``, where ``size``
    is ``local_crop.shape[0]``. Non-cardinal actions are never hazardous.
    Moves that leave the crop are treated as hazardous.

    Args:
        local_crop: ``[crop_size, crop_size]`` glyph array.
        action: Cardinal action index (0=N, 1=E, 2=S, 3=W).

    Returns:
        ``True`` if the destination cell is off-crop or holds a value from
        either hazard table, otherwise ``False``.
    """
    step = _CARDINAL_OFFSETS.get(action)
    if step is None:
        return False

    size = local_crop.shape[0]
    centre = size // 2
    row = centre + step[0]
    col = centre + step[1]

    off_crop = not (0 <= row < size and 0 <= col < size)
    if off_crop:
        return True

    value = int(local_crop[row, col])
    if value in _HAZARD_GLYPHS:
        return True
    return value in _HAZARD_CHARS
def top_k_filter(logits: Tensor, k: int | None) -> Tensor:
    """Keep only the ``k`` largest logits along the last axis.

    Entries strictly below the k-th largest value are replaced with
    ``-inf``, so they get zero probability after a softmax. Every entry
    tied with the k-th largest value is kept, so more than ``k`` entries
    can survive when there are ties.

    Args:
        logits: Raw logits. Shape ``[..., V]``.
        k: Number of top entries to keep. ``None``, ``k <= 0`` or
            ``k >= V`` disables filtering.

    Returns:
        A new tensor of the same shape with non-top-k entries set to
        ``-inf``, or ``logits`` itself (unchanged) when filtering is
        disabled.
    """
    # None is accepted as "disabled" so configs can leave top_k unset.
    if k is None or k <= 0 or k >= logits.shape[-1]:
        return logits
    kth_largest = logits.topk(k, dim=-1).values[..., -1:]  # [..., 1]
    return logits.masked_fill(logits < kth_largest, float("-inf"))
def _compute_remask_prob(
    strategy: str,
    eta: float,
    sigma_max: float,
    confidence: Tensor | None,
) -> Tensor | float:
    """Compute the per-token remasking probability for one denoising step.

    Args:
        strategy: One of ``"rescale"``, ``"cap"``, ``"conf"``.
        eta: Base remasking strength hyperparameter.
        sigma_max: ``1 - alpha_t(ratio)`` at the current step.
        confidence: Per-token confidence scores. Shape ``[B, L]``.
            Required only for the ``"conf"`` strategy.

    Returns:
        A Python float for ``"rescale"`` (``eta * sigma_max``) and
        ``"cap"`` (``min(eta, sigma_max)``). For ``"conf"``, a ``[B, L]``
        tensor ``eta * sigma_max * (1 - confidence)``, so less confident
        tokens are more likely to be remasked.

    Raises:
        ValueError: If ``strategy`` is unknown, or if ``strategy == "conf"``
            and ``confidence`` is ``None``.
    """
    if strategy == "rescale":
        return eta * sigma_max
    if strategy == "cap":
        return min(eta, sigma_max)
    if strategy == "conf":
        # Explicit check instead of `assert`: asserts are stripped under -O,
        # which would otherwise surface as an opaque TypeError below.
        if confidence is None:
            raise ValueError("conf strategy requires confidence")
        return eta * sigma_max * (1.0 - confidence)
    raise ValueError(f"Unknown remask strategy: {strategy}")
@torch.no_grad()
def remdm_sample(
    model: torch.nn.Module,
    local_obs: Tensor,
    global_obs: Tensor,
    cfg: SimpleNamespace,
    device: torch.device | str,
    physics_aware: bool = True,
    blind_global: bool = False,
    return_analytics: bool = False,
    num_steps: int | None = None,
) -> Tensor | tuple[Tensor, list, list[float], list[int]]:
    """Generate action sequences via iterative ReMDM denoising.

    The sequence starts fully masked. Each of the ``K`` steps runs one
    forward pass and samples a token for every position. Then:

    * on steps ``k < K``: the highest-confidence still-masked positions
      are committed (MaskGIT-style), and committed positions are
      stochastically re-masked according to ``cfg.remask_strategy``
      (ReMDM);
    * on the final step ``k == K``: every remaining MASK position is
      committed.

    Sampling draws from torch's global RNG.

    Args:
        model: Denoising model with forward signature
            ``(local_obs, global_obs, action_seq, t_discrete) -> dict``;
            the ``"actions"`` entry holds the logits.
        local_obs: Local crop observations. Shape ``[B, 9, 9]``.
        global_obs: Global map observations. Shape ``[B, 21, 79]``.
        cfg: Config namespace with ``seq_len``, ``mask_token``,
            ``action_dim``, ``diffusion_steps_eval``,
            ``num_diffusion_steps``, ``temperature`` (non-zero),
            ``top_k``, ``eta``, ``remask_strategy``, ``noise_schedule``.
        device: Torch device.
        physics_aware: If ``True``, soft-penalise hazardous cardinal actions
            by overriding their confidence to ``0.001`` before commitment
            ranking. Only checks the first ``_N_PHYSICS_CHECK`` positions.
        blind_global: If ``True``, zero out the global map observation
            (local-only ablation).
        return_analytics: If ``True``, also return per-step analytics for
            batch element 0 as
            ``(seq, path_per_step, tracking_confidence, tracking_masked)``.
        num_steps: Override number of denoising steps (default uses
            ``cfg.diffusion_steps_eval``).

    Returns:
        When ``return_analytics=False`` (default): fully committed action
        sequence of shape ``[B, seq_len]``, int64, with no MASK tokens.
        When ``return_analytics=True``: tuple
        ``(seq, path_per_step, tracking_confidence, tracking_masked_count)``
        where ``path_per_step`` is a list of ``[seq_len]`` numpy arrays,
        ``tracking_confidence`` a list of per-step avg unmasked confidence
        floats, and ``tracking_masked_count`` a list of masked-token counts.

    Raises:
        AssertionError: If MASK tokens remain in the output, e.g. when the
            step count resolves to 0 and the loop never runs. This check
            is skipped under ``python -O``.
        ValueError: From ``_compute_remask_prob`` for an unknown
            ``cfg.remask_strategy``.
    """
    B = local_obs.shape[0]
    seq_len = cfg.seq_len
    mask_token = cfg.mask_token
    action_dim = cfg.action_dim
    K = num_steps if num_steps is not None else cfg.diffusion_steps_eval
    schedule_fn = get_schedule(cfg.noise_schedule)
    # Safety net: every intermediate step selects at least 10% of seq_len
    # (minimum 1) positions for unmasking; only still-masked ones commit.
    min_keep = max(1, int(seq_len * 0.10))
    local_obs = local_obs.to(device)
    global_obs = global_obs.to(device)
    if blind_global:
        global_obs = torch.zeros_like(global_obs)
    # Pre-compute numpy local crops for physics checks (CPU, batch loop)
    local_np: np.ndarray | None = None  # [B, crop, crop]
    if physics_aware:
        local_np = local_obs.cpu().numpy()
    # Analytics buffers (only populated when return_analytics=True)
    path_per_step: list[np.ndarray] = []
    tracking_confidence: list[float] = []
    tracking_masked_count: list[int] = []
    # Start fully masked
    seq = torch.full(
        (B, seq_len), mask_token, dtype=torch.long, device=device
    )
    for k in range(1, K + 1):
        ratio = k / K  # fraction of the schedule completed after this step
        # Pass as tensor (not Python int) to avoid torch.compile recompilation.
        # Maps ratio onto the model's timestep scale; reaches 0 when k == K.
        t_discrete = torch.full(
            (B,), int(cfg.num_diffusion_steps * (1.0 - ratio)),
            dtype=torch.long, device=device,
        )
        # Forward pass
        out = model(local_obs, global_obs, seq, t_discrete)
        logits = out["actions"]  # [B, seq_len, vocab]
        # Mask invalid action tokens (indices >= action_dim; presumably
        # this includes the MASK id). NOTE: writes in place into the
        # tensor returned by the model.
        logits[:, :, action_dim:] = float("-inf")
        # Temperature scaling (cfg.temperature must be non-zero)
        logits = logits / cfg.temperature
        # Top-K filtering
        logits = top_k_filter(logits, cfg.top_k)
        # Sample predictions; filtered entries have probability 0
        probs = F.softmax(logits, dim=-1)  # [B, seq_len, vocab]
        preds = Categorical(probs=probs).sample()  # [B, seq_len]
        # Confidence: probability of the sampled token
        conf = probs.gather(
            -1, preds.unsqueeze(-1)
        ).squeeze(-1)  # [B, seq_len]
        # Physics softener: demote hazardous cardinal actions to conf=0.001
        # so they tend to rank low when choosing what to commit (under the
        # "conf" remask strategy this also raises their remask probability).
        # NOTE(review): every checked position is tested against the crop
        # centre, i.e. as if taken from the agent's current cell; confirm
        # this is intended for positions > 0.
        if physics_aware and local_np is not None:
            preds_np = preds.cpu().numpy()  # [B, seq_len]
            conf_override = conf.clone()
            for b in range(B):
                crop_b = np.asarray(local_np[b])  # [crop, crop]
                for pos in range(min(_N_PHYSICS_CHECK, seq_len)):
                    action = int(preds_np[b, pos])
                    if _check_hazard(crop_b, action):
                        conf_override[b, pos] = 0.001
            conf = conf_override
        is_masked = seq == mask_token  # [B, seq_len]
        if k < K:
            # MaskGIT progressive unmasking with min-keep guarantee.
            # NOTE(review): n_unmask is a per-step count that grows with
            # ratio (seq_len * ratio), not a cumulative target; confirm it
            # matches the JAX reference.
            n_unmask = max(min_keep, max(1, int(seq_len * ratio)))
            # Set confidence of non-masked positions to -1 so they
            # are not selected for unmasking
            unmask_scores = conf.clone()
            unmask_scores[~is_masked] = -1.0
            # For each batch element, unmask top-confidence masked positions
            _, topk_indices = unmask_scores.topk(
                n_unmask, dim=-1
            )  # [B, n_unmask]
            # Build scatter mask for positions to unmask
            unmask_mask = torch.zeros_like(seq, dtype=torch.bool)
            unmask_mask.scatter_(1, topk_indices, True)
            # If fewer than n_unmask positions are masked, topk also picks
            # committed (score -1) positions; they are dropped here.
            unmask_mask = unmask_mask & is_masked  # only unmask masked pos
            seq = torch.where(unmask_mask, preds, seq)
            # ReMDM stochastic remasking of committed (non-masked) positions,
            # including positions committed earlier in this same step.
            is_committed = seq != mask_token  # [B, seq_len]
            alpha_t_ratio = schedule_fn(
                torch.tensor(ratio, device=device)
            )
            # sigma_max = 1 - alpha_t(ratio); presumably schedule_fn returns
            # a 0-d tensor (``.item()`` requires a single element).
            sigma_max = (1.0 - alpha_t_ratio).item()
            # NOTE(review): `conf` is the probability of this step's sampled
            # token, which may differ from the token already committed at a
            # position; this matters for the "conf" strategy.
            remask_prob = _compute_remask_prob(
                cfg.remask_strategy, cfg.eta, sigma_max, conf
            )
            if isinstance(remask_prob, Tensor):
                # Per-position probability ("conf" strategy).
                do_remask = (
                    torch.rand_like(conf) < remask_prob
                ) & is_committed
            else:
                # One scalar probability shared by every position.
                do_remask = (
                    torch.rand(B, seq_len, device=device) < remask_prob
                ) & is_committed
            seq = torch.where(do_remask, mask_token, seq)
        else:
            # Final step: commit all remaining MASK tokens
            seq = torch.where(is_masked, preds, seq)
        # Analytics tracking (batch element 0 only)
        if return_analytics:
            path_per_step.append(seq[0].cpu().numpy().copy())
            still_masked = (seq[0] == mask_token)
            # Mean (possibly physics-overridden) confidence over positions
            # that are committed after this step.
            unmasked_conf = conf[0][~still_masked]
            avg_conf = (
                unmasked_conf.mean().item()
                if unmasked_conf.numel() > 0 else 0.0
            )
            tracking_confidence.append(avg_conf)
            tracking_masked_count.append(int(still_masked.sum().item()))
    # Sanity check; also fires when K <= 0 because the loop never ran.
    assert (seq != mask_token).all(), (
        "remdm_sample produced MASK tokens in final output"
    )
    if return_analytics:
        return seq, path_per_step, tracking_confidence, tracking_masked_count
    return seq
@torch.no_grad()
def greedy_sample(
    model: torch.nn.Module,
    local_obs: Tensor,
    global_obs: Tensor,
    cfg: SimpleNamespace,
    device: torch.device | str,
    blind_global: bool = False,
    num_steps: int | None = None,
) -> Tensor:
    """Deterministic MaskGIT decoding driven by argmax predictions.

    No temperature, top-K or ReMDM remasking is applied. Used by
    ``DataCollector`` during DAgger for deterministic rollouts, matching
    the reference ``run_model_episode`` behaviour.

    Args:
        model: Denoising model called as
            ``model(local_obs, global_obs, action_seq, t_discrete)`` and
            returning a dict with an ``"actions"`` logits tensor.
        local_obs: Shape ``[B, 9, 9]``.
        global_obs: Shape ``[B, 21, 79]``.
        cfg: Config namespace; reads ``seq_len``, ``mask_token``,
            ``action_dim``, ``diffusion_steps_eval`` and
            ``num_diffusion_steps``.
        device: Torch device.
        blind_global: Zero out global map (local-only ablation).
        num_steps: Override for the number of decoding steps (defaults to
            ``cfg.diffusion_steps_eval``).

    Returns:
        Fully committed action sequence ``[B, seq_len]``, int64.
    """
    n_batch = local_obs.shape[0]
    length = cfg.seq_len
    mask_id = cfg.mask_token
    n_actions = cfg.action_dim
    steps = cfg.diffusion_steps_eval if num_steps is None else num_steps

    local_obs = local_obs.to(device)
    global_obs = global_obs.to(device)
    if blind_global:
        global_obs = torch.zeros_like(global_obs)

    def _forward(tokens: Tensor, timestep: Tensor) -> Tensor:
        # One model call. Ids >= action_dim are disabled in place so that
        # softmax/argmax can only choose valid actions.
        raw = model(local_obs, global_obs, tokens, timestep)["actions"]
        raw[:, :, n_actions:] = float("-inf")
        return raw

    seq = torch.full(
        (n_batch, length), mask_id, dtype=torch.long, device=device,
    )
    for step in range(1, steps + 1):
        frac = step / steps
        t_now = torch.full(
            (n_batch,), int(cfg.num_diffusion_steps * (1.0 - frac)),
            dtype=torch.long, device=device,
        )
        best_prob, best_tok = F.softmax(
            _forward(seq, t_now), dim=-1
        ).max(dim=-1)  # [B, seq_len] each

        masked = seq == mask_id
        # Committed positions score -1 so they never outrank a masked one.
        ranking = best_prob.masked_fill(~masked, -1.0)
        chosen = ranking.topk(max(1, int(length * frac)), dim=-1).indices
        commit = torch.zeros_like(seq, dtype=torch.bool).scatter_(
            1, chosen, True
        ) & masked
        seq = torch.where(commit, best_tok, seq)

    # The last loop step (frac == 1) selects every position, so leftovers
    # only occur when the loop did not run (steps <= 0). Fill them greedily
    # at timestep 0.
    leftover = seq == mask_id
    if leftover.any():
        t_final = torch.zeros(n_batch, dtype=torch.long, device=device)
        fill = _forward(seq, t_final).argmax(dim=-1)
        seq = torch.where(leftover, fill, seq)
    return seq
def select_action(
    model: torch.nn.Module,
    local_obs: Tensor,
    global_obs: Tensor,
    cfg: SimpleNamespace,
    device: torch.device | str,
    physics_aware: bool = True,
    blind_global: bool = False,
) -> int:
    """Return the first action of a ReMDM plan for one observation.

    Unbatched 2-D observations get a leading batch axis of size 1 before
    being passed to ``remdm_sample``. Inputs with any other rank are
    forwarded unchanged.

    Args:
        model: Denoising model.
        local_obs: Shape ``[9, 9]`` or ``[1, 9, 9]``.
        global_obs: Shape ``[21, 79]`` or ``[1, 21, 79]``.
        cfg: Config namespace.
        device: Torch device.
        physics_aware: Forwarded to ``remdm_sample``.
        blind_global: Forwarded to ``remdm_sample``.

    Returns:
        The action at position 0 of batch element 0's plan, as a Python int.
    """
    batched_local = local_obs.unsqueeze(0) if local_obs.ndim == 2 else local_obs
    batched_global = (
        global_obs.unsqueeze(0) if global_obs.ndim == 2 else global_obs
    )
    plan = remdm_sample(
        model,
        batched_local,
        batched_global,
        cfg,
        device,
        physics_aware=physics_aware,
        blind_global=blind_global,
    )
    first_plan = plan[0]
    return first_plan[0].item()
|