from dataclasses import replace from typing import Protocol import torch from ltx_core.types import LatentState class Noiser(Protocol): """Protocol for adding noise to a latent state during diffusion.""" def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState: ... class GaussianNoiser(Noiser): """Adds Gaussian noise to a latent state, scaled by the denoise mask.""" def __init__(self, generator: torch.Generator): super().__init__() self.generator = generator def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState: noise = torch.randn( *latent_state.latent.shape, device=latent_state.latent.device, dtype=latent_state.latent.dtype, generator=self.generator, ) scaled_mask = latent_state.denoise_mask * noise_scale latent = noise * scaled_mask + latent_state.latent * (1 - scaled_mask) return replace( latent_state, latent=latent.to(latent_state.latent.dtype), )