Spaces:
Running on Zero
Running on Zero
File size: 1,088 Bytes
08c5e28 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | from dataclasses import replace
from typing import Protocol
import torch
from ltx_core.types import LatentState
class Noiser(Protocol):
    """Structural interface for objects that inject noise into a latent state.

    Because this is a ``typing.Protocol``, any object exposing a compatible
    ``__call__`` satisfies it for type checkers; subclassing is optional.
    """

    def __call__(
        self,
        latent_state: LatentState,
        noise_scale: float,
    ) -> LatentState:
        """Produce a ``LatentState`` with noise applied.

        Args:
            latent_state: The state to perturb.
            noise_scale: Strength of the noise; its exact interpretation is
                left to each implementation.

        Returns:
            A ``LatentState`` carrying the noised latent.
        """
        ...
class GaussianNoiser(Noiser):
    """Noiser that blends standard-normal noise into a latent via its denoise mask.

    The per-element blend weight is ``denoise_mask * noise_scale``. A weight of
    1 yields pure noise, a weight of 0 leaves the latent untouched, and values
    in between interpolate linearly between the two.
    """

    def __init__(self, generator: torch.Generator):
        """Remember the RNG used for every noise draw.

        Args:
            generator: ``torch.Generator`` handed to ``torch.randn`` on each
                call; seeding it makes the drawn noise reproducible.
        """
        super().__init__()
        self.generator = generator

    def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState:
        """Return a copy of ``latent_state`` whose ``latent`` is mixed with noise.

        Args:
            latent_state: State whose ``latent`` and ``denoise_mask`` drive
                the blend.
            noise_scale: Multiplier applied to ``denoise_mask`` before
                blending (default 1.0).

        Returns:
            A new ``LatentState`` built with ``dataclasses.replace``. Its
            ``latent`` has the input latent's dtype; all other fields are
            carried over unchanged.
        """
        clean = latent_state.latent
        # NOTE(review): torch.randn expects the generator to be on the same
        # device as the requested output; confirm callers build it on clean.device.
        gaussian = torch.randn(
            *clean.shape,
            generator=self.generator,
            dtype=clean.dtype,
            device=clean.device,
        )
        weight = latent_state.denoise_mask * noise_scale
        mixed = gaussian * weight + clean * (1 - weight)
        # The mask may presumably have a different dtype than the latent, which
        # would promote `mixed`; cast back so the latent's dtype is preserved.
        return replace(latent_state, latent=mixed.to(clean.dtype))
|