Dramabox / ltx2 /ltx_core /components /noisers.py
Manmay's picture
DramaBox Space — initial app + vendored ltx2
08c5e28 verified
from dataclasses import replace
from typing import Protocol
import torch
from ltx_core.types import LatentState
class Noiser(Protocol):
    """Structural interface for callables that add noise to a latent state.

    Anything exposing a compatible ``__call__`` satisfies this protocol;
    explicit subclassing is optional.
    """

    def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState:
        """Return a noised version of ``latent_state``.

        Args:
            latent_state: The state to add noise to.
            noise_scale: Scalar controlling how much noise is applied; the
                exact interpretation is left to the implementation.

        Returns:
            A ``LatentState`` carrying the noised result.
        """
        ...
class GaussianNoiser(Noiser):
    """Noiser that blends standard-normal noise into a latent state.

    The per-element blend weight is ``denoise_mask * noise_scale``: a weight
    of 1 replaces the latent with pure noise, a weight of 0 leaves it as is.
    """

    def __init__(self, generator: torch.Generator):
        """Store the random generator used for every noise draw.

        Args:
            generator: ``torch.Generator`` passed to ``torch.randn``; it is
                kept by reference, so its state advances with each call.
        """
        super().__init__()
        self.generator = generator

    def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState:
        """Return a copy of ``latent_state`` with noise mixed into ``latent``.

        Args:
            latent_state: State providing ``latent`` and ``denoise_mask``
                (the mask presumably broadcasts against the latent -- confirm
                against the ``LatentState`` definition).
            noise_scale: Multiplier applied to the mask before blending.

        Returns:
            A new state produced with ``dataclasses.replace``; only ``latent``
            differs, and it keeps the dtype of the input latent.
        """
        clean = latent_state.latent
        # Sample noise with the same shape, device and dtype as the latent.
        gaussian = torch.randn(
            *clean.shape,
            device=clean.device,
            dtype=clean.dtype,
            generator=self.generator,
        )
        weight = latent_state.denoise_mask * noise_scale
        blended = gaussian * weight + clean * (1 - weight)
        # The mask may promote the result to a wider dtype; cast it back.
        return replace(latent_state, latent=blended.to(clean.dtype))