| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange, repeat |
| import numpy as np |
| from tqdm import tqdm |
| from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D |
|
|
|
|
| class DownsampleCausal3D(nn.Module): |
|
|
| def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2): |
| super().__init__() |
| self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias) |
|
|
| def forward(self, hidden_states): |
| hidden_states = self.conv(hidden_states) |
| return hidden_states |
|
|
|
|
| class DownEncoderBlockCausal3D(nn.Module): |
|
|
| def __init__( |
| self, |
| in_channels, |
| out_channels, |
| dropout=0.0, |
| num_layers=1, |
| eps=1e-6, |
| num_groups=32, |
| add_downsample=True, |
| downsample_stride=2, |
| ): |
|
|
| super().__init__() |
| resnets = [] |
| for i in range(num_layers): |
| cur_in_channel = in_channels if i == 0 else out_channels |
| resnets.append( |
| ResnetBlockCausal3D( |
| in_channels=cur_in_channel, |
| out_channels=out_channels, |
| groups=num_groups, |
| dropout=dropout, |
| eps=eps, |
| )) |
| self.resnets = nn.ModuleList(resnets) |
|
|
| self.downsamplers = None |
| if add_downsample: |
| self.downsamplers = nn.ModuleList([DownsampleCausal3D( |
| out_channels, |
| out_channels, |
| stride=downsample_stride, |
| )]) |
|
|
| def forward(self, hidden_states): |
| for resnet in self.resnets: |
| hidden_states = resnet(hidden_states) |
|
|
| if self.downsamplers is not None: |
| for downsampler in self.downsamplers: |
| hidden_states = downsampler(hidden_states) |
|
|
| return hidden_states |
|
|
|
|
| class EncoderCausal3D(nn.Module): |
|
|
| def __init__( |
| self, |
| in_channels: int = 3, |
| out_channels: int = 16, |
| eps=1e-6, |
| dropout=0.0, |
| block_out_channels=[128, 256, 512, 512], |
| layers_per_block=2, |
| num_groups=32, |
| time_compression_ratio: int = 4, |
| spatial_compression_ratio: int = 8, |
| gradient_checkpointing=False, |
| ): |
| super().__init__() |
| self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1) |
| self.down_blocks = nn.ModuleList([]) |
|
|
| |
| output_channel = block_out_channels[0] |
| for i in range(len(block_out_channels)): |
| input_channel = output_channel |
| output_channel = block_out_channels[i] |
| is_final_block = i == len(block_out_channels) - 1 |
| num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio)) |
| num_time_downsample_layers = int(np.log2(time_compression_ratio)) |
|
|
| add_spatial_downsample = bool(i < num_spatial_downsample_layers) |
| add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block) |
|
|
| downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) |
| downsample_stride_T = (2,) if add_time_downsample else (1,) |
| downsample_stride = tuple(downsample_stride_T + downsample_stride_HW) |
| down_block = DownEncoderBlockCausal3D( |
| in_channels=input_channel, |
| out_channels=output_channel, |
| dropout=dropout, |
| num_layers=layers_per_block, |
| eps=eps, |
| num_groups=num_groups, |
| add_downsample=bool(add_spatial_downsample or add_time_downsample), |
| downsample_stride=downsample_stride, |
| ) |
| self.down_blocks.append(down_block) |
|
|
| |
| self.mid_block = UNetMidBlockCausal3D( |
| in_channels=block_out_channels[-1], |
| dropout=dropout, |
| eps=eps, |
| num_groups=num_groups, |
| attention_head_dim=block_out_channels[-1], |
| ) |
| |
| self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps) |
| self.conv_act = nn.SiLU() |
| self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3) |
|
|
| self.gradient_checkpointing = gradient_checkpointing |
|
|
| def forward(self, hidden_states): |
| hidden_states = self.conv_in(hidden_states) |
| if self.training and self.gradient_checkpointing: |
|
|
| def create_custom_forward(module): |
|
|
| def custom_forward(*inputs): |
| return module(*inputs) |
|
|
| return custom_forward |
|
|
| |
| for down_block in self.down_blocks: |
| torch.utils.checkpoint.checkpoint( |
| create_custom_forward(down_block), |
| hidden_states, |
| use_reentrant=False, |
| ) |
| |
| hidden_states = torch.utils.checkpoint.checkpoint( |
| create_custom_forward(self.mid_block), |
| hidden_states, |
| use_reentrant=False, |
| ) |
| else: |
| |
| for down_block in self.down_blocks: |
| hidden_states = down_block(hidden_states) |
| |
| hidden_states = self.mid_block(hidden_states) |
| |
| hidden_states = self.conv_norm_out(hidden_states) |
| hidden_states = self.conv_act(hidden_states) |
| hidden_states = self.conv_out(hidden_states) |
|
|
| return hidden_states |
|
|
|
|
| class HunyuanVideoVAEEncoder(nn.Module): |
|
|
| def __init__( |
| self, |
| in_channels=3, |
| out_channels=16, |
| eps=1e-6, |
| dropout=0.0, |
| block_out_channels=[128, 256, 512, 512], |
| layers_per_block=2, |
| num_groups=32, |
| time_compression_ratio=4, |
| spatial_compression_ratio=8, |
| gradient_checkpointing=False, |
| ): |
| super().__init__() |
| self.encoder = EncoderCausal3D( |
| in_channels=in_channels, |
| out_channels=out_channels, |
| eps=eps, |
| dropout=dropout, |
| block_out_channels=block_out_channels, |
| layers_per_block=layers_per_block, |
| num_groups=num_groups, |
| time_compression_ratio=time_compression_ratio, |
| spatial_compression_ratio=spatial_compression_ratio, |
| gradient_checkpointing=gradient_checkpointing, |
| ) |
| self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1) |
| self.scaling_factor = 0.476986 |
|
|
|
|
| def forward(self, images): |
| latents = self.encoder(images) |
| latents = self.quant_conv(latents) |
| latents = latents[:, :16] |
| latents = latents * self.scaling_factor |
| return latents |
| |
|
|
| def build_1d_mask(self, length, left_bound, right_bound, border_width): |
| x = torch.ones((length,)) |
| if not left_bound: |
| x[:border_width] = (torch.arange(border_width) + 1) / border_width |
| if not right_bound: |
| x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) |
| return x |
| |
|
|
| def build_mask(self, data, is_bound, border_width): |
| _, _, T, H, W = data.shape |
| t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) |
| h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1]) |
| w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2]) |
|
|
| t = repeat(t, "T -> T H W", T=T, H=H, W=W) |
| h = repeat(h, "H -> T H W", T=T, H=H, W=W) |
| w = repeat(w, "W -> T H W", T=T, H=H, W=W) |
|
|
| mask = torch.stack([t, h, w]).min(dim=0).values |
| mask = rearrange(mask, "T H W -> 1 1 T H W") |
| return mask |
| |
|
|
| def tile_forward(self, hidden_states, tile_size, tile_stride): |
| B, C, T, H, W = hidden_states.shape |
| size_t, size_h, size_w = tile_size |
| stride_t, stride_h, stride_w = tile_stride |
|
|
| |
| tasks = [] |
| for t in range(0, T, stride_t): |
| if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue |
| for h in range(0, H, stride_h): |
| if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue |
| for w in range(0, W, stride_w): |
| if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue |
| t_, h_, w_ = t + size_t, h + size_h, w + size_w |
| tasks.append((t, t_, h, h_, w, w_)) |
|
|
| |
| torch_dtype = self.quant_conv.weight.dtype |
| data_device = hidden_states.device |
| computation_device = self.quant_conv.weight.device |
|
|
| weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device) |
| values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device) |
|
|
| for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"): |
| hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device) |
| hidden_states_batch = self.forward(hidden_states_batch).to(data_device) |
| if t > 0: |
| hidden_states_batch = hidden_states_batch[:, :, 1:] |
|
|
| mask = self.build_mask( |
| hidden_states_batch, |
| is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W), |
| border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8) |
| ).to(dtype=torch_dtype, device=data_device) |
|
|
| target_t = 0 if t==0 else t // 4 + 1 |
| target_h = h // 8 |
| target_w = w // 8 |
| values[ |
| :, |
| :, |
| target_t: target_t + hidden_states_batch.shape[2], |
| target_h: target_h + hidden_states_batch.shape[3], |
| target_w: target_w + hidden_states_batch.shape[4], |
| ] += hidden_states_batch * mask |
| weight[ |
| :, |
| :, |
| target_t: target_t + hidden_states_batch.shape[2], |
| target_h: target_h + hidden_states_batch.shape[3], |
| target_w: target_w + hidden_states_batch.shape[4], |
| ] += mask |
| return values / weight |
|
|
|
|
| def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)): |
| latents = latents.to(self.quant_conv.weight.dtype) |
| return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride) |
|
|
|
|
| @staticmethod |
| def state_dict_converter(): |
| return HunyuanVideoVAEEncoderStateDictConverter() |
|
|
|
|
| class HunyuanVideoVAEEncoderStateDictConverter: |
|
|
| def __init__(self): |
| pass |
|
|
| def from_diffusers(self, state_dict): |
| state_dict_ = {} |
| for name in state_dict: |
| if name.startswith('encoder.') or name.startswith('quant_conv.'): |
| state_dict_[name] = state_dict[name] |
| return state_dict_ |
|
|