File size: 1,088 Bytes
08c5e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from dataclasses import replace
from typing import Protocol

import torch

from ltx_core.types import LatentState


class Noiser(Protocol):
    """Structural interface for objects that inject noise into a ``LatentState``.

    Anything whose ``__call__`` accepts a latent state and a float ``noise_scale``
    and returns a ``LatentState`` satisfies this protocol; explicit inheritance is
    optional because ``typing.Protocol`` matches structurally.
    """

    def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState:
        """Return a ``LatentState`` derived from ``latent_state`` with noise applied.

        ``noise_scale`` presumably modulates how strongly noise is applied; the
        exact meaning is defined by each implementation.
        """
        ...


class GaussianNoiser(Noiser):
    """Blend fresh standard-normal noise into a latent state, weighted by its denoise mask.

    Noise is drawn from the ``torch.Generator`` supplied at construction, so a
    seeded generator yields a repeatable sequence of noise draws.
    """

    def __init__(self, generator: torch.Generator):
        super().__init__()
        # Every draw in __call__ goes through this generator.
        self.generator: torch.Generator = generator

    def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState:
        """Return a copy of ``latent_state`` whose latent is mixed with Gaussian noise.

        Element-wise, the new latent is ``noise * w + latent * (1 - w)`` with
        ``w = denoise_mask * noise_scale``: where ``w`` is 0 the original value is
        kept, and where ``w`` is 1 it is replaced entirely by noise.

        Args:
            latent_state: State providing ``latent`` (the tensor to noise) and
                ``denoise_mask`` (per-element blend weights; presumably the same
                shape as, or broadcastable to, ``latent`` with values in [0, 1] —
                confirm against callers).
            noise_scale: Global multiplier applied to the mask before blending.

        Returns:
            A new state built with ``dataclasses.replace``; only ``latent`` differs,
            and it is cast back to the original latent dtype. The input state is
            not mutated.
        """
        clean = latent_state.latent
        sample = torch.randn(
            *clean.shape,
            device=clean.device,
            dtype=clean.dtype,
            generator=self.generator,
        )
        weight = latent_state.denoise_mask * noise_scale
        noisy_part = sample * weight
        kept_part = clean * (1 - weight)
        # The mask may have a different dtype than the latent, so the sum can be
        # promoted; cast back so the returned latent keeps its original dtype.
        blended = (noisy_part + kept_part).to(clean.dtype)
        return replace(latent_state, latent=blended)