| |
| |
| |
|
|
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| from einops import rearrange |
| from torch import Tensor, nn |
| from torch.nn.functional import silu as swish |
|
|
| from opensora.registry import MODELS |
| from opensora.utils.ckpt import load_checkpoint |
|
|
| from .utils import DiagonalGaussianDistribution |
|
|
|
|
| @dataclass |
| class AutoEncoderConfig: |
| from_pretrained: str | None |
| cache_dir: str | None |
| resolution: int |
| in_channels: int |
| ch: int |
| out_ch: int |
| ch_mult: list[int] |
| num_res_blocks: int |
| z_channels: int |
| scale_factor: float |
| shift_factor: float |
| sample: bool = True |
|
|
|
|
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over a 4D feature map, with a residual.

    The input is group-normalised, projected to queries/keys/values by 1x1
    convolutions, and every spatial position attends to every other position.
    The attended features go through ``proj_out`` and are added back to the input.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        # Submodule names are kept as-is so checkpoint state_dict keys still match.
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Return the attended feature map (same shape as ``h_``), before ``proj_out``."""
        normed = self.norm(h_)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)
        batch, channels, height, width = query.shape
        num_tokens = height * width

        def to_tokens(feat: Tensor) -> Tensor:
            # (B, C, H, W) -> (B, 1, H*W, C): a single head, one token per pixel.
            return feat.reshape(batch, channels, num_tokens).transpose(1, 2).unsqueeze(1).contiguous()

        attended = nn.functional.scaled_dot_product_attention(
            to_tokens(query), to_tokens(key), to_tokens(value)
        )
        # (B, 1, H*W, C) -> (B, C, H, W)
        return attended.squeeze(1).transpose(1, 2).reshape(batch, channels, height, width)

    def forward(self, x: Tensor) -> Tensor:
        """Apply attention and add the projected result to the input."""
        attended = self.attention(x)
        return x + self.proj_out(attended)
|
|
|
|
| class ResnetBlock(nn.Module): |
| def __init__(self, in_channels: int, out_channels: int): |
| super().__init__() |
| self.in_channels = in_channels |
| out_channels = in_channels if out_channels is None else out_channels |
| self.out_channels = out_channels |
|
|
| self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) |
| self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) |
| self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) |
| self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) |
| if self.in_channels != self.out_channels: |
| self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) |
|
|
| def forward(self, x): |
| h = x |
| h = self.norm1(h) |
| h = swish(h) |
| h = self.conv1(h) |
|
|
| h = self.norm2(h) |
| h = swish(h) |
| h = self.conv2(h) |
|
|
| if self.in_channels != self.out_channels: |
| x = self.nin_shortcut(x) |
|
|
| return x + h |
|
|
|
|
class Downsample(nn.Module):
    """Halve the spatial size with a stride-2, unpadded 3x3 convolution.

    Before the convolution, the input is zero-padded by one pixel on the
    right and bottom edges only. The padding is asymmetric, not the
    symmetric ``padding=1`` of the conv layer.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor) -> Tensor:
        """Pad right/bottom by one, then apply the stride-2 convolution."""
        # ``pad`` takes (left, right, top, bottom) for the last two dimensions.
        return self.conv(nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0))
|
|
|
|
class Upsample(nn.Module):
    """Double the spatial size: 2x nearest-neighbour resize, then a 3x3 conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        """Return the upsampled, convolved feature map."""
        enlarged = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        # Same-padding 3x3 conv keeps the doubled spatial size.
        return self.conv(enlarged)
|
|
|
|
class Encoder(nn.Module):
    """2D convolutional encoder mapping images to latent distribution parameters.

    Layout: ``conv_in`` -> one level per entry of ``config.ch_mult`` (each with
    ``num_res_blocks`` ResNet blocks, plus a stride-2 ``Downsample`` after every
    level except the last) -> middle (ResNet, attention, ResNet) ->
    GroupNorm / SiLU / ``conv_out``.

    The output has ``2 * config.z_channels`` channels. These are presumably the
    mean/log-variance halves consumed by ``DiagonalGaussianDistribution``;
    confirm in ``.utils``.
    """

    def __init__(self, config: AutoEncoderConfig):
        super().__init__()
        self.ch = config.ch
        self.num_resolutions = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks
        self.resolution = config.resolution
        self.in_channels = config.in_channels

        # downsampling
        self.conv_in = nn.Conv2d(config.in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        # Input multiplier of level i is the output multiplier of level i-1
        # (level 0 starts from ``ch``).
        in_ch_mult = (1,) + tuple(config.ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            # Never populated here, so the attention branch in ``forward`` is
            # skipped; kept so the ``down.N.attn`` module layout is unchanged.
            attn = nn.ModuleList()
            block_in = config.ch * in_ch_mult[i_level]
            block_out = config.ch * config.ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * config.z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` into a map with ``2 * z_channels`` channels."""
        # Only the most recent activation is ever consumed, so carry one running
        # tensor. The previous list of all intermediates kept every one of them
        # alive until return, which inflated peak memory during inference.
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        return self.conv_out(h)
|
|
|
|
class Decoder(nn.Module):
    """2D convolutional decoder mapping latents back to image space.

    Layout: ``conv_in`` -> middle (ResNet, attention, ResNet) -> upsampling
    levels from coarsest to finest -> GroupNorm / SiLU / ``conv_out``. Each level
    holds ``num_res_blocks + 1`` ResNet blocks; every level except level 0 ends
    with a 2x ``Upsample``. ``self.up[i]`` is resolution level ``i``.
    """

    def __init__(self, config: AutoEncoderConfig):
        super().__init__()
        self.ch = config.ch
        self.num_resolutions = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks
        self.resolution = config.resolution
        self.in_channels = config.in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # Start from the widest / coarsest level.
        coarsest = self.num_resolutions - 1
        block_in = config.ch * config.ch_mult[coarsest]
        latent_res = config.resolution // 2**coarsest
        self.z_shape = (1, config.z_channels, latent_res, latent_res)

        # z -> block_in
        self.conv_in = nn.Conv2d(config.z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling: modules are built coarsest-first (this fixes the
        # parameter-initialisation order), but each stage is inserted at the
        # front, so the final list is indexed by resolution level.
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            width = config.ch * config.ch_mult[i_level]
            resnets = nn.ModuleList()
            for _ in range(self.num_res_blocks + 1):
                resnets.append(ResnetBlock(in_channels=block_in, out_channels=width))
                block_in = width
            stage = nn.Module()
            stage.block = resnets
            # Never populated here; kept for the ``up.N.attn`` module layout.
            stage.attn = nn.ModuleList()
            if i_level != 0:
                stage.upsample = Upsample(block_in)
            self.up.insert(0, stage)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, config.out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        """Decode the latent feature map ``z`` into ``out_ch`` image channels."""
        h = self.conv_in(z)

        # middle
        for layer in (self.mid.block_1, self.mid.attn_1, self.mid.block_2):
            h = layer(h)

        # upsampling, coarsest level first
        for i_level in reversed(range(self.num_resolutions)):
            stage = self.up[i_level]
            has_attn = len(stage.attn) > 0
            for i_block in range(self.num_res_blocks + 1):
                h = stage.block[i_block](h)
                if has_attn:
                    h = stage.attn[i_block](h)
            if i_level != 0:
                h = stage.upsample(h)

        # end
        return self.conv_out(swish(self.norm_out(h)))
|
|
|
|
class AutoEncoder(nn.Module):
    """Frame-wise 2D VAE applied to 5D ``(b, c, t, h, w)`` video tensors.

    Frames are folded into the batch dimension and encoded/decoded
    independently by the 2D ``Encoder``/``Decoder``. Latents are normalised as
    ``scale_factor * (z - shift_factor)`` on encode, and the inverse is applied
    on decode.
    """

    def __init__(self, config: AutoEncoderConfig):
        super().__init__()
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.scale_factor = config.scale_factor
        self.shift_factor = config.shift_factor
        self.sample = config.sample

    def encode_(self, x: Tensor) -> tuple[Tensor, DiagonalGaussianDistribution]:
        """Encode ``x`` and return ``(normalised latent, posterior)``.

        The latent is a posterior sample when ``self.sample`` is true, otherwise
        the posterior mode. It keeps the ``(b, c, t, h, w)`` layout.
        """
        T = x.shape[2]
        x = rearrange(x, "b c t h w -> (b t) c h w")
        params = self.encoder(x)
        params = rearrange(params, "(b t) c h w -> b c t h w", t=T)
        posterior = DiagonalGaussianDistribution(params)
        if self.sample:
            z = posterior.sample()
        else:
            z = posterior.mode()
        z = self.scale_factor * (z - self.shift_factor)
        return z, posterior

    def encode(self, x: Tensor) -> Tensor:
        """Return only the normalised latent from ``encode_``."""
        return self.encode_(x)[0]

    def decode(self, z: Tensor) -> Tensor:
        """Undo the latent normalisation and decode ``(b, c, t, h, w)`` latents frame by frame."""
        T = z.shape[2]
        z = rearrange(z, "b c t h w -> (b t) c h w")
        z = z / self.scale_factor + self.shift_factor
        x = self.decoder(z)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=T)
        return x

    def forward(self, x: Tensor) -> tuple[Tensor, DiagonalGaussianDistribution, Tensor]:
        """Encode then decode ``x``; return ``(reconstruction, posterior, latent)``."""
        # (A leftover no-op ``x.shape[2]`` expression statement was removed here.)
        z, posterior = self.encode_(x)
        x_rec = self.decode(z)
        return x_rec, posterior, z

    def get_last_layer(self):
        """Return the weight of the decoder's final convolution.

        Presumably used by the training loss (e.g. adaptive loss weighting);
        confirm at the call sites.
        """
        return self.decoder.conv_out.weight
|
|
|
|
@MODELS.register_module("autoencoder_2d")
def AutoEncoderFlux(
    from_pretrained: str | None,
    cache_dir=None,
    resolution=256,
    in_channels=3,
    ch=128,
    out_ch=3,
    ch_mult=(1, 2, 4, 4),
    num_res_blocks=2,
    z_channels=16,
    scale_factor=0.3611,
    shift_factor=0.1159,
    device_map: str | torch.device = "cuda",
    torch_dtype: torch.dtype = torch.bfloat16,
    sample: bool = True,
) -> AutoEncoder:
    """Build the 2D ``AutoEncoder`` and optionally load a checkpoint.

    Args:
        from_pretrained: Checkpoint passed to ``load_checkpoint``. Loading is
            skipped when this is falsy (e.g. ``None`` or ``""``).
        cache_dir: Forwarded to ``load_checkpoint`` and stored on the config.
        resolution, in_channels, ch, out_ch, ch_mult, num_res_blocks,
        z_channels: Architecture hyper-parameters (see ``AutoEncoderConfig``).
            ``ch_mult`` may be any sequence of ints; it is copied into a new
            list.
        scale_factor, shift_factor: Latent normalisation constants.
        device_map: Device used as the construction context, so parameters
            are allocated there directly. Also forwarded to
            ``load_checkpoint``.
        torch_dtype: Dtype the model is cast to after construction.
        sample: If true, encoding samples the posterior; otherwise it uses
            the posterior mode. Defaults to the previous behaviour (sampling).

    Returns:
        The constructed (and possibly checkpoint-loaded) ``AutoEncoder``.
    """
    config = AutoEncoderConfig(
        from_pretrained=from_pretrained,
        cache_dir=cache_dir,
        resolution=resolution,
        in_channels=in_channels,
        ch=ch,
        out_ch=out_ch,
        # The default used to be a list literal: one object shared by every call
        # and aliased into every config. Copy into a fresh list instead.
        ch_mult=list(ch_mult),
        num_res_blocks=num_res_blocks,
        z_channels=z_channels,
        scale_factor=scale_factor,
        shift_factor=shift_factor,
        sample=sample,
    )
    with torch.device(device_map):
        model = AutoEncoder(config).to(torch_dtype)
    if from_pretrained:
        model = load_checkpoint(model, from_pretrained, cache_dir=cache_dir, device_map=device_map)
    return model
|
|