"""Forward masking process q(z_t | x_0).

Ported from the Craftax JAX implementation (src/diffusion/forward.py).
Each token is independently replaced with mask_token with probability
sigma_t = 1 - alpha_t. PAD positions are never masked.
"""
|
|
| from __future__ import annotations |
|
|
| from typing import Callable |
|
|
| import torch |
| from torch import Tensor |
|
|
|
|
def q_sample(
    x0: Tensor,
    t: Tensor,
    mask_token: int,
    pad_token: int,
    schedule_fn: Callable[[Tensor], Tensor],
) -> Tensor:
    """Draw z_t ~ q(z_t | x_0) from the forward masking process.

    Each non-PAD token of ``x0`` is independently replaced by
    ``mask_token`` with probability ``sigma_t = 1 - alpha_t``, where
    ``alpha_t = schedule_fn(t)``.

    Args:
        x0: Clean action sequences. Shape ``[B, L]``, dtype int64.
        t: Per-sample diffusion time in [0, 1]. Shape ``[B]``.
        mask_token: Integer ID of the MASK token.
        pad_token: Integer ID of the PAD token.
        schedule_fn: Noise schedule mapping t to alpha(t).

    Returns:
        Noisy sequence z_t. Shape ``[B, L]``, dtype int64.
        PAD positions are preserved unchanged.
    """
    # Per-sample masking probability sigma_t = 1 - alpha_t, broadcast
    # over the sequence axis: [B] -> [B, 1].
    mask_prob = (1.0 - schedule_fn(t)).unsqueeze(-1)

    # One independent uniform draw per token position.
    replace = torch.rand_like(x0, dtype=torch.float32) < mask_prob
    # PAD positions are exempt from masking, so clear them from the
    # replacement mask up front (equivalent to masking then restoring).
    replace &= x0 != pad_token

    return torch.where(replace, mask_token, x0)
|
|