Spaces:
Running on Zero
Running on Zero
| import torch | |
| from einops import rearrange | |
| from ltx_core.model.upsampler.pixel_shuffle import PixelShuffleND | |
| from ltx_core.model.upsampler.res_block import ResBlock | |
| from ltx_core.model.upsampler.spatial_rational_resampler import SpatialRationalResampler | |
| from ltx_core.model.video_vae import VideoEncoder | |
class LatentUpsampler(torch.nn.Module):
    """
    Model to upsample VAE latents spatially and/or temporally.

    The network is: initial conv -> GroupNorm(32 groups) -> SiLU -> ResBlocks ->
    upsampler -> ResBlocks -> final conv. The final conv maps back to
    `in_channels`, so the output has the same channel count as the input.

    Args:
        in_channels (`int`): Number of channels in the input latent
        mid_channels (`int`): Number of channels in the middle layers. It is passed to
            `GroupNorm(32, mid_channels)`.
        num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
        dims (`int`): Number of dimensions for convolutions (2 or 3). Any value other than 2
            selects `Conv3d` for the initial and final convolutions.
        spatial_upsample (`bool`): Whether to spatially upsample the latent
        temporal_upsample (`bool`): Whether to temporally upsample the latent
        spatial_scale (`float`): Scale factor for spatial upsampling. It is only forwarded to
            the rational resampler; the pixel-shuffle upsamplers do not read it.
        rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling.
            Only consulted when `spatial_upsample` is set and `temporal_upsample` is not.

    Raises:
        ValueError: If neither `spatial_upsample` nor `temporal_upsample` is set.
    """

    def __init__(
        self,
        in_channels: int = 128,
        mid_channels: int = 512,
        num_blocks_per_stage: int = 4,
        dims: int = 3,
        spatial_upsample: bool = True,
        temporal_upsample: bool = False,
        spatial_scale: float = 2.0,
        rational_resampler: bool = False,
    ):
        super().__init__()

        # Fail fast: reject an unusable configuration before any convolution or
        # ResBlock weights are allocated and initialised.
        if not (spatial_upsample or temporal_upsample):
            raise ValueError("Either spatial_upsample or temporal_upsample must be True")

        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.num_blocks_per_stage = num_blocks_per_stage
        self.dims = dims
        self.spatial_upsample = spatial_upsample
        self.temporal_upsample = temporal_upsample
        self.spatial_scale = float(spatial_scale)
        self.rational_resampler = rational_resampler

        # NOTE: submodules are created in a fixed order (stem, pre-blocks, upsampler,
        # post-blocks, final conv) so state_dict key order and seeded weight
        # initialisation stay stable.
        conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d

        self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)
        self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
        self.initial_activation = torch.nn.SiLU()

        self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])

        if spatial_upsample and temporal_upsample:
            # Joint spatio-temporal upsampling: 8x channels feeding a 3D pixel shuffle.
            self.upsampler = torch.nn.Sequential(
                torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(3),
            )
        elif spatial_upsample:
            if rational_resampler:
                self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
            else:
                # Spatial only: 4x channels feeding a 2D pixel shuffle.
                self.upsampler = torch.nn.Sequential(
                    torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
                    PixelShuffleND(2),
                )
        else:
            # Temporal only (guaranteed by the validation above): 2x channels
            # feeding a 1D pixel shuffle.
            self.upsampler = torch.nn.Sequential(
                torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(1),
            )

        self.post_upsample_res_blocks = torch.nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Upsample a latent tensor.

        Args:
            latent: Input latent laid out as [B, C, F, H, W].

        Returns:
            torch.Tensor: The upsampled latent, laid out as [B, C, F', H', W'].
            When `temporal_upsample` is set (and `dims` is not 2), the first
            frame produced by the upsampler is dropped.
        """
        b, _, f, _, _ = latent.shape
        per_frame = self.dims == 2

        # With 2D convolutions every frame is processed as an independent image,
        # so frames are folded into the batch axis for the whole network.
        x = rearrange(latent, "b c f h w -> (b f) c h w") if per_frame else latent

        x = self.initial_conv(x)
        x = self.initial_norm(x)
        x = self.initial_activation(x)

        for block in self.res_blocks:
            x = block(x)

        if per_frame:
            x = self.upsampler(x)
        elif self.temporal_upsample:
            x = self.upsampler(x)
            # remove the first frame after upsampling.
            # This is done because the first frame encodes one pixel frame.
            x = x[:, :, 1:, :, :]
        elif isinstance(self.upsampler, SpatialRationalResampler):
            # The rational resampler is handed the 5D tensor directly.
            x = self.upsampler(x)
        else:
            # The spatial pixel-shuffle upsampler uses Conv2d, so frames are
            # folded into the batch axis just for this step.
            x = rearrange(x, "b c f h w -> (b f) c h w")
            x = self.upsampler(x)
            x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)

        for block in self.post_upsample_res_blocks:
            x = block(x)

        x = self.final_conv(x)

        if per_frame:
            # Restore the frame axis that was folded into the batch axis above.
            x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)

        return x
def upsample_video(latent: torch.Tensor, video_encoder: VideoEncoder, upsampler: "LatentUpsampler") -> torch.Tensor:
    """
    Upsample a latent with `upsampler`, bracketing the call with the video
    encoder's per-channel statistics.

    The latent is first passed through `per_channel_statistics.un_normalize`,
    then through `upsampler`, and the result is passed through
    `per_channel_statistics.normalize`.

    Args:
        latent: Input latent tensor of shape [B, C, F, H, W].
        video_encoder: VideoEncoder exposing `per_channel_statistics`.
        upsampler: LatentUpsampler module (any callable on the latent) that performs the upsampling.

    Returns:
        torch.Tensor: Upsampled and re-normalized latent tensor.
    """
    denormalized = video_encoder.per_channel_statistics.un_normalize(latent)
    upsampled = upsampler(denormalized)
    return video_encoder.per_channel_statistics.normalize(upsampled)