| """ |
| MIT License |
| |
| Copyright (c) 2021 Wilson Yan |
| |
| Permission is hereby granted, free of charge, to any person obtaining a copy |
| of this software and associated documentation files (the "Software"), to deal |
| in the Software without restriction, including without limitation the rights |
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| copies of the Software, and to permit persons to whom the Software is |
| furnished to do so, subject to the following conditions: |
| |
| The above copyright notice and this permission notice shall be included in all |
| copies or substantial portions of the Software. |
| |
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| SOFTWARE. |
| |
| |
| This file is copied from https://github.com/wilson1yan/VideoGPT/blob/master/videogpt/vqvae.py |
| We adapted it to Hugging Face AutoModel for easier model loading. |
| """ |
|
|
|
|
| import os |
| import math |
| import numpy as np |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.distributed as dist |
|
|
| from .attention import MultiHeadAttention |
| from ._utils import shift_dim |
| from transformers import PreTrainedModel |
| from .configuration_vqvae import VQVAEConfig |
|
|
|
|
class VQVAE(PreTrainedModel):
    """VideoGPT-style VQ-VAE for video, packaged as a Hugging Face ``PreTrainedModel``.

    Built from a ``VQVAEConfig``. ``__init__`` reads ``embedding_dim``,
    ``n_codes``, ``n_hiddens``, ``n_res_layers`` and ``downsample`` from it;
    :attr:`latent_shape` additionally reads ``sequence_length`` and
    ``resolution``.
    """

    config_class = VQVAEConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding_dim = config.embedding_dim
        self.n_codes = config.n_codes

        # Encoder/decoder work at ``n_hiddens`` channels; the kernel-1 convs
        # below project into and out of the codebook's ``embedding_dim`` space.
        self.encoder = Encoder(config.n_hiddens, config.n_res_layers, config.downsample)
        self.decoder = Decoder(config.n_hiddens, config.n_res_layers, config.downsample)

        self.pre_vq_conv = SamePadConv3d(config.n_hiddens, config.embedding_dim, 1)
        self.post_vq_conv = SamePadConv3d(config.embedding_dim, config.n_hiddens, 1)

        self.codebook = Codebook(config.n_codes, config.embedding_dim)

    @property
    def latent_shape(self):
        """Shape of the latent code grid.

        Each of ``(sequence_length, resolution, resolution)`` is
        floor-divided by the matching entry of ``downsample``.
        """
        # Upstream VideoGPT read these from ``self.args``. This class never sets
        # ``args``; ``PreTrainedModel.__init__`` stores the config as
        # ``self.config``. So the old code always raised AttributeError.
        # NOTE(review): assumes VQVAEConfig defines ``sequence_length`` and
        # ``resolution``; confirm in configuration_vqvae.
        cfg = self.config
        input_shape = (cfg.sequence_length, cfg.resolution, cfg.resolution)
        return tuple(s // d for s, d in zip(input_shape, cfg.downsample))

    def encode(self, x, include_embeddings=False):
        """Quantize ``x`` to codebook indices.

        Args:
            x: input video batch. The encoder's first conv expects 3 channels
                (presumably ``(B, 3, T, H, W)``; confirm with callers).
            include_embeddings: if true, also return the quantized embeddings.

        Returns:
            The code indices, or ``(indices, embeddings)`` when
            ``include_embeddings`` is true.

        This runs ``Codebook.forward``, so in training mode it also updates
        the codebook's EMA statistics.
        """
        h = self.pre_vq_conv(self.encoder(x))
        vq_output = self.codebook(h)
        if include_embeddings:
            return vq_output['encodings'], vq_output['embeddings']
        return vq_output['encodings']

    def decode(self, encodings):
        """Reconstruct a video from integer code indices."""
        h = F.embedding(encodings, self.codebook.embeddings)
        # F.embedding appends the feature dim last; the convs want it at dim 1.
        h = self.post_vq_conv(shift_dim(h, -1, 1))
        return self.decoder(h)

    def decode_from_embeddings(self, embeddings):
        """Snap ``embeddings`` to their nearest codes, then decode those codes."""
        encodings = self.codebook.search_indices(embeddings)
        return self.decode(encodings)

    def forward(self, x):
        """Encode, quantize and reconstruct ``x``.

        Returns:
            ``(recon_loss, x_recon, vq_output)``. ``vq_output`` is the
            codebook's dict (``embeddings``, ``encodings``,
            ``commitment_loss``, ``perplexity``). The commitment loss is NOT
            folded into ``recon_loss``; callers must add it themselves.
        """
        z = self.pre_vq_conv(self.encoder(x))
        vq_output = self.codebook(z)
        x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings']))
        # NOTE(review): constant scaling kept from upstream VideoGPT; presumably
        # a dataset-variance normalizer. Confirm before changing.
        recon_loss = F.mse_loss(x_recon, x) / 0.06

        return recon_loss, x_recon, vq_output
|
|
|
|
class AxialBlock(nn.Module):
    """Sum of three non-causal axial self-attentions over the last three dims.

    ``forward`` moves the channel dim (1) last. It then adds the outputs of
    attention along axial dims -2, -3 and -4 (named ``attn_w``, ``attn_h``,
    ``attn_t``) and moves the channels back to dim 1. It presumably expects
    ``(B, C, T, H, W)`` input; confirm with callers.
    """

    def __init__(self, n_hiddens, n_head):
        super().__init__()
        shared = {
            'shape': (0, 0, 0),
            'dim_q': n_hiddens,
            'dim_kv': n_hiddens,
            'n_head': n_head,
            'n_layer': 1,
            'causal': False,
            'attn_type': 'axial',
        }
        self.attn_w = MultiHeadAttention(attn_kwargs={'axial_dim': -2}, **shared)
        self.attn_h = MultiHeadAttention(attn_kwargs={'axial_dim': -3}, **shared)
        self.attn_t = MultiHeadAttention(attn_kwargs={'axial_dim': -4}, **shared)

    def forward(self, x):
        channels_last = shift_dim(x, 1, -1)
        mixed = self.attn_w(channels_last, channels_last, channels_last)
        mixed = mixed + self.attn_h(channels_last, channels_last, channels_last)
        mixed = mixed + self.attn_t(channels_last, channels_last, channels_last)
        return shift_dim(mixed, -1, 1)
|
|
|
|
class AttentionResidualBlock(nn.Module):
    """Residual block: ``x + block(x)``.

    The residual branch is BN/ReLU -> 3x3x3 conv to ``n_hiddens // 2``
    channels -> BN/ReLU -> 1x1x1 conv back to ``n_hiddens`` -> BN/ReLU ->
    2-head ``AxialBlock``.
    """

    def __init__(self, n_hiddens):
        super().__init__()
        bottleneck = n_hiddens // 2
        # Layer order fixes the ``block.<idx>`` state-dict keys; do not reorder.
        layers = [
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU(),
            SamePadConv3d(n_hiddens, bottleneck, 3, bias=False),
            nn.BatchNorm3d(bottleneck),
            nn.ReLU(),
            SamePadConv3d(bottleneck, n_hiddens, 1, bias=False),
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU(),
            AxialBlock(n_hiddens, 2),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual
|
|
class Codebook(nn.Module):
    """Vector-quantization codebook whose entries are updated by EMA.

    All state lives in buffers (not ``nn.Parameter``s), which the training
    forward pass rewrites in place:

    * ``embeddings``: ``(n_codes, embedding_dim)`` code vectors.
    * ``N``: ``(n_codes,)`` moving average (decay 0.99) of how many inputs
      were assigned to each code.
    * ``z_avg``: ``(n_codes, embedding_dim)`` moving average (decay 0.99) of
      the sum of inputs assigned to each code.

    Inputs ``z`` are channel-first: dim 0 is the batch and dim 1 holds the
    feature vector that is quantized.
    """

    def __init__(self, n_codes, embedding_dim):
        super().__init__()
        self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
        self.register_buffer('N', torch.zeros(n_codes))
        self.register_buffer('z_avg', self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        # Plain attribute, not a buffer, so it is not saved in the state dict.
        # NOTE(review): a codebook restored from a checkpoint still starts with
        # ``_need_init == True``. Its first forward pass in training mode
        # therefore overwrites the loaded embeddings with samples of that batch.
        self._need_init = True

    @staticmethod
    def _distributed():
        """Return True only if torch.distributed is available AND initialized."""
        # ``dist.is_initialized`` may be absent on PyTorch builds compiled
        # without distributed support; ``dist.is_available`` is always defined.
        return dist.is_available() and dist.is_initialized()

    def _tile(self, x):
        """Ensure ``x`` has at least ``n_codes`` rows.

        Args:
            x: 2-D tensor ``(d, width)`` of flattened inputs.

        Returns:
            ``x`` itself if ``d >= n_codes``. Otherwise ``x`` is repeated
            ``ceil(n_codes / d)`` times along dim 0 and Gaussian noise with std
            ``0.01 / sqrt(width)`` is added, so the copies are not identical.
        """
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def _flatten(self, z):
        """Move dim 1 (features) last and flatten everything else: ``(M, C)``."""
        return shift_dim(z, 1, -1).flatten(end_dim=-2)

    def _nearest_codes(self, flat_inputs):
        """Index of the nearest code (squared L2) for every row of ``flat_inputs``."""
        # |x|^2 - 2 x.e + |e|^2 avoids materializing an (M, n_codes, C) tensor.
        distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
            - 2 * flat_inputs @ self.embeddings.t() \
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
        return torch.argmin(distances, dim=1)

    def _init_embeddings(self, z):
        """Seed ``embeddings``/``z_avg`` with random input rows and set ``N`` to 1."""
        self._need_init = False
        y = self._tile(self._flatten(z))
        _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
        if self._distributed():
            # Use rank 0's sample so every replica starts from the same codebook.
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))

    def search_indices(self, z):
        """Nearest-code indices for ``z``, shape ``(z.shape[0], *z.shape[2:])``.

        Unlike :meth:`forward`, this never modifies any buffer.
        """
        encoding_indices = self._nearest_codes(self._flatten(z))
        return encoding_indices.view(z.shape[0], *z.shape[2:])

    def _ema_update(self, flat_inputs, encode_onehot):
        """EMA-update ``N``/``z_avg``/``embeddings`` and restart unused codes."""
        n_total = encode_onehot.sum(dim=0)
        encode_sum = flat_inputs.t() @ encode_onehot
        if self._distributed():
            dist.all_reduce(n_total)
            dist.all_reduce(encode_sum)

        self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
        self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)

        # Smoothed counts (their sum is still ``n``) avoid dividing by zero
        # for codes that were never selected.
        n = self.N.sum()
        weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
        encode_normalized = self.z_avg / weights.unsqueeze(1)
        self.embeddings.data.copy_(encode_normalized)

        # Codes whose usage average fell below 1 are replaced by random inputs.
        y = self._tile(flat_inputs)
        _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
        if self._distributed():
            dist.broadcast(_k_rand, 0)

        usage = (self.N.view(self.n_codes, 1) >= 1).float()
        self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))

    def forward(self, z):
        """Quantize ``z``; in training mode also update the codebook.

        Returns:
            dict with
              ``embeddings``: quantized ``z`` (same shape as ``z``), with a
                  straight-through gradient to ``z``;
              ``encodings``: code indices, shape ``(z.shape[0], *z.shape[2:])``;
              ``commitment_loss``: ``0.25 * mse(z, quantized.detach())``;
              ``perplexity``: exp of the entropy of code usage in this batch.
        """
        if self._need_init and self.training:
            self._init_embeddings(z)
        flat_inputs = self._flatten(z)
        encoding_indices = self._nearest_codes(flat_inputs)
        encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
        encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])

        embeddings = F.embedding(encoding_indices, self.embeddings)
        embeddings = shift_dim(embeddings, -1, 1)

        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())

        if self.training:
            self._ema_update(flat_inputs, encode_onehot)

        # Straight-through estimator: the forward value is the quantized
        # tensor; the gradient passes to ``z`` unchanged.
        embeddings_st = (embeddings - z).detach() + z

        avg_probs = torch.mean(encode_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return dict(embeddings=embeddings_st, encodings=encoding_indices,
                    commitment_loss=commitment_loss, perplexity=perplexity)

    def dictionary_lookup(self, encodings):
        """Map integer code indices to their embedding vectors (no state change)."""
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings
|
|
class Encoder(nn.Module):
    """Strided 3-D conv downsampler followed by attention residual blocks.

    Args:
        n_hiddens: channel width of every hidden layer.
        n_res_layers: number of ``AttentionResidualBlock`` layers.
        downsample: three per-axis downsampling factors, in the order of the
            Conv3d spatial dims. Each factor should be a power of two: the
            axis gets ``int(log2(factor))`` stride-2 convs.
    """

    def __init__(self, n_hiddens, n_res_layers, downsample):
        super().__init__()
        n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
        self.convs = nn.ModuleList()
        max_ds = n_times_downsample.max()
        for i in range(max_ds):
            in_channels = 3 if i == 0 else n_hiddens
            # Stride 2 on axes that still need halving, stride 1 elsewhere.
            stride = tuple(2 if d > 0 else 1 for d in n_times_downsample)
            conv = SamePadConv3d(in_channels, n_hiddens, 4, stride=stride)
            self.convs.append(conv)
            n_times_downsample -= 1
        # conv_last consumes the last strided conv's output (n_hiddens channels),
        # or the raw 3-channel input when no strided conv was built. Reusing
        # the loop's ``in_channels`` here was a NameError for max_ds == 0 and
        # gave the wrong width (3) for max_ds == 1.
        last_in_channels = n_hiddens if max_ds > 0 else 3
        self.conv_last = SamePadConv3d(last_in_channels, n_hiddens, kernel_size=3)

        self.res_stack = nn.Sequential(
            *[AttentionResidualBlock(n_hiddens)
              for _ in range(n_res_layers)],
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU()
        )

    def forward(self, x):
        """Map a 3-channel input to an ``n_hiddens``-channel feature map."""
        h = x
        for conv in self.convs:
            h = F.relu(conv(h))
        h = self.conv_last(h)
        h = self.res_stack(h)
        return h
|
|
|
|
class Decoder(nn.Module):
    """Attention residual stack followed by strided transposed convs.

    There is one transposed conv per halving step of the largest ``upsample``
    factor. The last one outputs 3 channels and every earlier one is followed
    by ReLU. Axes whose factor is exhausted use stride 1.
    """

    def __init__(self, n_hiddens, n_res_layers, upsample):
        super().__init__()
        blocks = [AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)]
        self.res_stack = nn.Sequential(*blocks, nn.BatchNorm3d(n_hiddens), nn.ReLU())

        remaining = np.array([int(math.log2(f)) for f in upsample])
        n_stages = remaining.max()
        self.convts = nn.ModuleList()
        for stage in range(n_stages):
            is_final = stage == n_stages - 1
            stride = tuple(2 if r > 0 else 1 for r in remaining)
            self.convts.append(
                SamePadConvTranspose3d(n_hiddens, 3 if is_final else n_hiddens, 4,
                                       stride=stride)
            )
            remaining -= 1

    def forward(self, x):
        out = self.res_stack(x)
        final_idx = len(self.convts) - 1
        for idx, convt in enumerate(self.convts):
            out = convt(out)
            if idx != final_idx:
                out = F.relu(out)
        return out
|
|
|
|
| |
class SamePadConv3d(nn.Module):
    """3-D convolution preceded by explicit asymmetric padding.

    Each axis is padded by ``kernel_size - stride`` in total; the odd extra
    element goes on the leading side. The wrapped ``nn.Conv3d`` then runs with
    ``padding=0``, so each spatial dim becomes ``input // stride`` (unchanged
    for stride 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True):
        super().__init__()
        kernel_size = (kernel_size,) * 3 if isinstance(kernel_size, int) else kernel_size
        stride = (stride,) * 3 if isinstance(stride, int) else stride

        # F.pad wants (left, right) pairs starting from the LAST dimension.
        pads = []
        for k, s in reversed(list(zip(kernel_size, stride))):
            total = k - s
            pads.extend((total - total // 2, total // 2))
        self.pad_input = tuple(pads)

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=0, bias=bias)

    def forward(self, x):
        padded = F.pad(x, self.pad_input)
        return self.conv(padded)
|
|
|
|
class SamePadConvTranspose3d(nn.Module):
    """3-D transposed convolution preceded by asymmetric input padding.

    The input is padded like ``SamePadConv3d`` (``kernel_size - stride`` per
    axis, odd element on the leading side). The wrapped
    ``nn.ConvTranspose3d`` then uses ``padding = kernel_size - 1``. With
    stride 1 the spatial size is preserved; with kernel 4 and stride 2 each
    strided axis doubles.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride, stride)

        # Flat (left, right, ...) tuple for F.pad, last axis first.
        totals = [k - s for k, s in zip(kernel_size, stride)]
        self.pad_input = tuple(
            amount
            for total in reversed(totals)
            for amount in (total - total // 2, total // 2)
        )

        self.convt = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            bias=bias,
            padding=tuple(k - 1 for k in kernel_size),
        )

    def forward(self, x):
        padded = F.pad(x, self.pad_input)
        return self.convt(padded)
|
|