from typing import List, Optional, Tuple
import math

import torch
import torch.nn as nn
import torch.amp as amp
import numpy as np
import torch.nn.functional as F
from einops import rearrange, repeat

from .wan_video_dit import flash_attention
from ..core.device.npu_compatible_device import get_device_type
from ..core.gradient import gradient_checkpoint_forward


class RMSNorm_FP32(torch.nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def broadcat(tensors, dim=-1):
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
    shape_len = list(shape_lens)[0]
    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all(
        [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
    ), "invalid dimensions for broadcastable concatenation"
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)


def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, head_dim, cp_split_hw=None):
        """Rotary positional embedding for 3D inputs.

        Reference: https://blog.eleuther.ai/rotary-embeddings/
        Paper: https://arxiv.org/pdf/2104.09864.pdf

        Args:
            head_dim: Dimension of each attention head.
            cp_split_hw: Optional context-parallel split of the (H, W) grid.
        """
        super().__init__()
        self.head_dim = head_dim
        assert self.head_dim % 8 == 0, 'Dim must be a multiple of 8 for 3D RoPE.'
        self.cp_split_hw = cp_split_hw
        # We assume the longest side of the grid is no larger than 512, i.e. 512 * 8 = 4096 input pixels.
        self.base = 10000
        self.freqs_dict = {}

    def register_grid_size(self, grid_size):
        if grid_size not in self.freqs_dict:
            self.freqs_dict.update({
                grid_size: self.precompute_freqs_cis_3d(grid_size)
            })

    def precompute_freqs_cis_3d(self, grid_size):
        num_frames, height, width = grid_size
        dim_t = self.head_dim - 4 * (self.head_dim // 6)
        dim_h = 2 * (self.head_dim // 6)
        dim_w = 2 * (self.head_dim // 6)
        freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
        freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
        freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
        grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)
        grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)
        grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)
        grid_t = torch.from_numpy(grid_t).float()
        grid_h = torch.from_numpy(grid_h).float()
        grid_w = torch.from_numpy(grid_w).float()
        freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
        freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
        freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
        freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
        freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
        freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
        freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
        # (T H W D)
        freqs = rearrange(freqs, "T H W D -> (T H W) D")
        # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
        #     with torch.no_grad():
        #         freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width)
        #         freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)
        #         freqs = rearrange(freqs, "T H W D -> (T H W) D")
        return freqs

    def forward(self, q, k, grid_size):
        """Apply 3D RoPE.

        Args:
            q: query of shape [B, head, seq, head_dim]
            k: key of shape [B, head, seq, head_dim]
            grid_size: (T, H, W) token grid used to look up the precomputed frequencies.

        Returns:
            query and key with the same shape as the input.
        """
        if grid_size not in self.freqs_dict:
            self.register_grid_size(grid_size)
        freqs_cis = self.freqs_dict[grid_size].to(q.device)
        q_, k_ = q.float(), k.float()
        freqs_cis = freqs_cis.float().to(q.device)
        cos, sin = freqs_cis.cos(), freqs_cis.sin()
        cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
        q_ = (q_ * cos) + (rotate_half(q_) * sin)
        k_ = (k_ * cos) + (rotate_half(k_) * sin)
        return q_.type_as(q), k_.type_as(k)
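

# Minimal usage sketch for the 3D RoPE above (illustrative only; the batch size,
# head count and (T, H, W) grid below are made-up values, not model defaults).
def _example_rope_usage():
    rope = RotaryPositionalEmbedding(head_dim=64)
    q = torch.randn(1, 8, 4 * 16 * 16, 64)  # [B, heads, T*H*W, head_dim]
    k = torch.randn(1, 8, 4 * 16 * 16, 64)
    # Frequencies are precomputed and cached per (T, H, W) token grid.
    q_rot, k_rot = rope(q, k, (4, 16, 16))
    return q_rot.shape, k_rot.shape  # both stay [1, 8, 1024, 64]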


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = False,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params: dict = None,
        cp_split_hw: Optional[List[int]] = None
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.enable_flashattn3 = enable_flashattn3
        self.enable_flashattn2 = enable_flashattn2
        self.enable_xformers = enable_xformers
        self.enable_bsa = enable_bsa
        self.bsa_params = bsa_params
        self.cp_split_hw = cp_split_hw
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.proj = nn.Linear(dim, dim)
        self.rope_3d = RotaryPositionalEmbedding(
            self.head_dim,
            cp_split_hw=cp_split_hw
        )

    def _process_attn(self, q, k, v, shape):
        q = rearrange(q, "B H S D -> B S (H D)")
        k = rearrange(k, "B H S D -> B S (H D)")
        v = rearrange(v, "B H S D -> B S (H D)")
        x = flash_attention(q, k, v, num_heads=self.num_heads)
        x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads)
        return x

    def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
        """Self-attention over video tokens.

        x: [B, N, C]; shape is the (T, H, W) latent grid. When num_cond_latents > 0,
        the leading condition tokens attend only to themselves, while the noise tokens
        attend to the full sequence. If return_kv is set, the pre-RoPE key/value tensors
        are also returned for later reuse as a KV cache.
        """
        B, N, C = x.shape
        qkv = self.qkv(x)
        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4))  # [3, B, H, N, D]
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        if return_kv:
            k_cache, v_cache = k.clone(), v.clone()
        q, k = self.rope_3d(q, k, shape)
        # cond mode
        if num_cond_latents is not None and num_cond_latents > 0:
            num_cond_latents_thw = num_cond_latents * (N // shape[0])
            # process the condition tokens
            q_cond = q[:, :, :num_cond_latents_thw].contiguous()
            k_cond = k[:, :, :num_cond_latents_thw].contiguous()
            v_cond = v[:, :, :num_cond_latents_thw].contiguous()
            x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)
            # process the noise tokens
            q_noise = q[:, :, num_cond_latents_thw:].contiguous()
            x_noise = self._process_attn(q_noise, k, v, shape)
            # merge x_cond and x_noise
            x = torch.cat([x_cond, x_noise], dim=2).contiguous()
        else:
            x = self._process_attn(q, k, v, shape)
        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)  # [B, H, N, D] --> [B, N, H, D]
        x = x.reshape(x_output_shape)  # [B, N, H, D] --> [B, N, C]
        x = self.proj(x)
        if return_kv:
            return x, (k_cache, v_cache)
        else:
            return x

    def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
        """Self-attention for new (noise) tokens, reusing cached condition keys/values.

        kv_cache holds the pre-RoPE (k, v) tensors of the condition tokens; RoPE is
        applied over the concatenated sequence so positions stay consistent.
        """
        B, N, C = x.shape
        qkv = self.qkv(x)
        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4))  # [3, B, H, N, D]
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        T, H, W = shape
        k_cache, v_cache = kv_cache
        assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]
        if k_cache.shape[0] == 1:
            k_cache = k_cache.repeat(B, 1, 1, 1)
            v_cache = v_cache.repeat(B, 1, 1, 1)
        if num_cond_latents is not None and num_cond_latents > 0:
            k_full = torch.cat([k_cache, k], dim=2).contiguous()
            v_full = torch.cat([v_cache, v], dim=2).contiguous()
            q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()
            q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
            q = q_padding[:, :, -N:].contiguous()
        x = self._process_attn(q, k_full, v_full, shape)
        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)  # [B, H, N, D] --> [B, N, H, D]
        x = x.reshape(x_output_shape)  # [B, N, H, D] --> [B, N, C]
        x = self.proj(x)
        return x
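

# Illustrative sketch of the two attention paths above (hypothetical sizes; it assumes
# the flash_attention kernel imported at the top of this file runs on the current device).
def _example_attention_kv_cache():
    attn = Attention(dim=512, num_heads=8)
    T, H, W, num_cond = 4, 16, 16, 1
    x = torch.randn(1, T * H * W, 512)  # condition tokens first, then noise tokens
    # First pass: compute attention and keep the pre-RoPE K/V of all tokens.
    out, kv = attn(x, shape=(T, H, W), num_cond_latents=num_cond, return_kv=True)
    # Later pass: only the noise tokens are fed, reusing the cached condition K/V.
    kv_cond = (kv[0][:, :, :num_cond * H * W], kv[1][:, :, :num_cond * H * W])
    x_noise = torch.randn(1, (T - num_cond) * H * W, 512)
    out_noise = attn.forward_with_kv_cache(x_noise, shape=(T - num_cond, H, W),
                                           num_cond_latents=num_cond, kv_cache=kv_cond)
    return out.shape, out_noise.shape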


class MultiHeadCrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        enable_flashattn3=False,
        enable_flashattn2=False,
        enable_xformers=False,
    ):
        super(MultiHeadCrossAttention, self).__init__()
        assert dim % num_heads == 0, "d_model must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q_linear = nn.Linear(dim, dim)
        self.kv_linear = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.enable_flashattn3 = enable_flashattn3
        self.enable_flashattn2 = enable_flashattn2
        self.enable_xformers = enable_xformers

    def _process_cross_attn(self, x, cond, kv_seqlen):
        B, N, C = x.shape
        assert C == self.dim and cond.shape[2] == self.dim
        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)
        q, k = self.q_norm(q), self.k_norm(k)
        q = rearrange(q, "B S H D -> B S (H D)")
        k = rearrange(k, "B S H D -> B S (H D)")
        v = rearrange(v, "B S H D -> B S (H D)")
        x = flash_attention(q, k, v, num_heads=self.num_heads)
        x = x.view(B, -1, C)
        x = self.proj(x)
        return x

    def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
        """
        x: [B, N, C]
        cond: [B, M, C]
        """
        if num_cond_latents is None or num_cond_latents == 0:
            return self._process_cross_attn(x, cond, kv_seqlen)
        else:
            B, N, C = x.shape
            if num_cond_latents is not None and num_cond_latents > 0:
                assert shape is not None, "SHOULD pass in the shape"
                num_cond_latents_thw = num_cond_latents * (N // shape[0])
                x_noise = x[:, num_cond_latents_thw:]  # [B, N_noise, C]
                output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen)  # [B, N_noise, C]
                output = torch.cat([
                    torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
                    output_noise
                ], dim=1).contiguous()
            else:
                raise NotImplementedError
            return output


class LayerNorm_FP32(nn.LayerNorm):
    def __init__(self, dim, eps, elementwise_affine):
        super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        out = F.layer_norm(
            inputs.float(),
            self.normalized_shape,
            None if self.weight is None else self.weight.float(),
            None if self.bias is None else self.bias.float(),
            self.eps
        ).to(origin_dtype)
        return out


def modulate_fp32(norm_func, x, shift, scale):
    # Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D).
    # Ensure the modulation params are fp32.
    assert shift.dtype == torch.float32 and scale.dtype == torch.float32
    dtype = x.dtype
    x = norm_func(x.to(torch.float32))
    x = x * (scale + 1) + shift
    x = x.to(dtype)
    return x
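

# Small worked example of the fp32 modulation helper (hypothetical shapes): the norm and
# the shift/scale arithmetic run in float32, and the result is cast back to the input dtype.
def _example_modulate_fp32():
    norm = LayerNorm_FP32(64, eps=1e-6, elementwise_affine=False)
    x = torch.randn(2, 8, 16, 64, dtype=torch.bfloat16)    # [B, T, tokens-per-frame, D]
    shift = torch.randn(2, 8, 1, 64, dtype=torch.float32)  # per-frame shift
    scale = torch.randn(2, 8, 1, 64, dtype=torch.float32)  # per-frame scale
    out = modulate_fp32(norm, x, shift, scale)
    return out.dtype  # matches x.dtype (bfloat16)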


class FinalLayer_FP32(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_patch = num_patch
        self.out_channels = out_channels
        self.adaln_tembed_dim = adaln_tembed_dim
        self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))

    def forward(self, x, t, latent_shape):
        # timestep shape: [B, T, C]
        assert t.dtype == torch.float32
        B, N, C = x.shape
        T, _, _ = latent_shape
        with amp.autocast(get_device_type(), dtype=torch.float32):
            shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1)  # [B, T, 1, C]
            x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
        x = self.linear(x)
        return x


class FeedForwardSwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
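

# Worked example of the SwiGLU hidden-size rule above (values are illustrative):
# with dim=1024 and hidden_dim=4*1024=4096, the 2/3 factor gives 2730, which is then
# rounded up to the next multiple of 256, i.e. 2816, so w1/w3 are 1024x2816 projections.
def _example_swiglu_ffn():
    ffn = FeedForwardSwiGLU(dim=1024, hidden_dim=4 * 1024)
    assert ffn.hidden_dim == 2816
    x = torch.randn(2, 10, 1024)
    return ffn(x).shape  # [2, 10, 1024]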


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, t_embed_dim, frequency_embedding_size=256):
        super().__init__()
        self.t_embed_dim = t_embed_dim
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(t_embed_dim, t_embed_dim, bias=True),
        )

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
        freqs = freqs.to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        if t_freq.dtype != dtype:
            t_freq = t_freq.to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
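

# Minimal sketch of the sinusoidal timestep embedding (illustrative values): each scalar
# timestep is mapped to a 256-dim sin/cos vector and then projected by the MLP above.
def _example_timestep_embedder():
    embedder = TimestepEmbedder(t_embed_dim=512)
    t = torch.tensor([0.0, 250.0, 999.0])       # one timestep per batch element
    t_emb = embedder(t, dtype=torch.float32)
    freqs = TimestepEmbedder.timestep_embedding(t, 256)
    return freqs.shape, t_emb.shape             # [3, 256] and [3, 512]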


class CaptionEmbedder(nn.Module):
    """
    Embeds caption features into vector representations.
    """
    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size = hidden_size
        self.y_proj = nn.Sequential(
            nn.Linear(in_channels, hidden_size, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    def forward(self, caption):
        B, _, N, C = caption.shape
        caption = self.y_proj(caption)
        return caption


class PatchEmbed3D(nn.Module):
    """Video to Patch Embedding.

    Args:
        patch_size (int): Patch token size. Default: (2, 4, 4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(
        self,
        patch_size=(2, 4, 4),
        in_chans=3,
        embed_dim=96,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, D, H, W = x.size()
        if W % self.patch_size[2] != 0:
            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
        if H % self.patch_size[1] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
        if D % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
        B, C, T, H, W = x.shape
        x = self.proj(x)  # (B C T H W)
        if self.norm is not None:
            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
        return x
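

# Shape sketch for the 3D patch embedding (hypothetical video size): a [B, C, T, H, W]
# latent is projected with a strided Conv3d and flattened to [B, N, embed_dim] tokens,
# where N = (T/1) * (H/2) * (W/2) for the (1, 2, 2) patch size used by the model below.
def _example_patch_embed_3d():
    embed = PatchEmbed3D(patch_size=(1, 2, 2), in_chans=16, embed_dim=128)
    x = torch.randn(1, 16, 4, 32, 32)
    tokens = embed(x)
    return tokens.shape  # [1, 4 * 16 * 16, 128] = [1, 1024, 128]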


class LongCatSingleStreamBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: int,
        adaln_tembed_dim: int,
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = False,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params=None,
        cp_split_hw=None
    ):
        super().__init__()
        self.hidden_size = hidden_size
        # scale and gate modulation
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True)
        )
        self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
        self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
        self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)
        self.attn = Attention(
            dim=hidden_size,
            num_heads=num_heads,
            enable_flashattn3=enable_flashattn3,
            enable_flashattn2=enable_flashattn2,
            enable_xformers=enable_xformers,
            enable_bsa=enable_bsa,
            bsa_params=bsa_params,
            cp_split_hw=cp_split_hw
        )
        self.cross_attn = MultiHeadCrossAttention(
            dim=hidden_size,
            num_heads=num_heads,
            enable_flashattn3=enable_flashattn3,
            enable_flashattn2=enable_flashattn2,
            enable_xformers=enable_xformers,
        )
        self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio))

    def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):
        """
        x: [B, N, C]
        y: [1, N_valid_tokens, C]
        t: [B, T, C_t]
        y_seqlen: [B]; a list of per-sample text lengths
        latent_shape: latent shape of a single item
        """
        x_dtype = x.dtype
        B, N, C = x.shape
        T, _, _ = latent_shape  # S != T*H*W in case of CP split on H*W.
        # compute modulation params in fp32
        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
            shift_msa, scale_msa, gate_msa, \
                shift_mlp, scale_mlp, gate_mlp = \
                self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1)  # [B, T, 1, C]
        # self attn with modulation
        x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)
        if kv_cache is not None:
            kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))
            attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)
        else:
            attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)
        if return_kv:
            x_s, kv_cache = attn_outputs
        else:
            x_s = attn_outputs
        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
            x = x + (gate_msa * x_s.view(B, -1, N // T, C)).view(B, -1, C)  # [B, N, C]
        x = x.to(x_dtype)
        # cross attn
        if not skip_crs_attn:
            if kv_cache is not None:
                num_cond_latents = None
            x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)
        # ffn with modulation
        x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N // T, C), shift_mlp, scale_mlp).view(B, -1, C)
        x_s = self.ffn(x_m)
        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
            x = x + (gate_mlp * x_s.view(B, -1, N // T, C)).view(B, -1, C)  # [B, N, C]
        x = x.to(x_dtype)
        if return_kv:
            return x, kv_cache
        else:
            return x


class LongCatVideoTransformer3DModel(torch.nn.Module):
    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 16,
        hidden_size: int = 4096,
        depth: int = 48,
        num_heads: int = 32,
        caption_channels: int = 4096,
        mlp_ratio: int = 4,
        adaln_tembed_dim: int = 512,
        frequency_embedding_size: int = 256,
        # default params
        patch_size: Tuple[int] = (1, 2, 2),
        # attention config
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = True,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]},
        cp_split_hw: Optional[List[int]] = [1, 1],
        text_tokens_zero_pad: bool = True,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cp_split_hw = cp_split_hw
        self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels,
            hidden_size=hidden_size,
        )
        self.blocks = nn.ModuleList(
            [
                LongCatSingleStreamBlock(
                    hidden_size=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    adaln_tembed_dim=adaln_tembed_dim,
                    enable_flashattn3=enable_flashattn3,
                    enable_flashattn2=enable_flashattn2,
                    enable_xformers=enable_xformers,
                    enable_bsa=enable_bsa,
                    bsa_params=bsa_params,
                    cp_split_hw=cp_split_hw
                )
                for i in range(depth)
            ]
        )
        self.final_layer = FinalLayer_FP32(
            hidden_size,
            np.prod(self.patch_size),
            out_channels,
            adaln_tembed_dim,
        )
        self.gradient_checkpointing = False
        self.text_tokens_zero_pad = text_tokens_zero_pad
        self.lora_dict = {}
        self.active_loras = []

    def enable_loras(self, lora_key_list=[]):
        self.disable_all_loras()
        module_loras = {}  # {module_name: [lora1, lora2, ...]}
        model_device = next(self.parameters()).device
        model_dtype = next(self.parameters()).dtype
        for lora_key in lora_key_list:
            if lora_key in self.lora_dict:
                for lora in self.lora_dict[lora_key].loras:
                    lora.to(model_device, dtype=model_dtype, non_blocking=True)
                    module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
                    if module_name not in module_loras:
                        module_loras[module_name] = []
                    module_loras[module_name].append(lora)
                self.active_loras.append(lora_key)
        for module_name, loras in module_loras.items():
            module = self._get_module_by_name(module_name)
            if not hasattr(module, 'org_forward'):
                module.org_forward = module.forward
            module.forward = self._create_multi_lora_forward(module, loras)

    def _create_multi_lora_forward(self, module, loras):
        def multi_lora_forward(x, *args, **kwargs):
            weight_dtype = x.dtype
            org_output = module.org_forward(x, *args, **kwargs)
            total_lora_output = 0
            for lora in loras:
                if lora.use_lora:
                    lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))
                    lx = lora.lora_up(lx)
                    lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale
                    total_lora_output += lora_output
            return org_output + total_lora_output
        return multi_lora_forward

    def _get_module_by_name(self, module_name):
        try:
            module = self
            for part in module_name.split('.'):
                module = getattr(module, part)
            return module
        except AttributeError as e:
            raise ValueError(f"Cannot find module: {module_name}, error: {e}")

    def disable_all_loras(self):
        for name, module in self.named_modules():
            if hasattr(module, 'org_forward'):
                module.forward = module.org_forward
                delattr(module, 'org_forward')
        for lora_key, lora_network in self.lora_dict.items():
            for lora in lora_network.loras:
                lora.to("cpu")
        self.active_loras.clear()

    def enable_bsa(self):
        for block in self.blocks:
            block.attn.enable_bsa = True

    def disable_bsa(self):
        for block in self.blocks:
            block.attn.enable_bsa = False

    def forward(
        self,
        hidden_states,
        timestep,
        encoder_hidden_states,
        encoder_attention_mask=None,
        num_cond_latents=0,
        return_kv=False,
        kv_cache_dict={},
        skip_crs_attn=False,
        offload_kv_cache=False,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        B, _, T, H, W = hidden_states.shape
        N_t = T // self.patch_size[0]
        N_h = H // self.patch_size[1]
        N_w = W // self.patch_size[2]
        assert self.patch_size[0] == 1, "Currently, 3D x_embedder should not compress the temporal dimension."
        # expand the shape of timestep from [B] to [B, T]
        if len(timestep.shape) == 1:
            timestep = timestep.unsqueeze(1).expand(-1, N_t).clone()  # [B, T]
        timestep[:, :num_cond_latents] = 0
        dtype = hidden_states.dtype
        hidden_states = hidden_states.to(dtype)
        timestep = timestep.to(dtype)
        encoder_hidden_states = encoder_hidden_states.to(dtype)
        hidden_states = self.x_embedder(hidden_states)  # [B, N, C]
        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
            t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1)  # [B, T, C_t]
        encoder_hidden_states = self.y_embedder(encoder_hidden_states)  # [B, 1, N_token, C]
        if self.text_tokens_zero_pad and encoder_attention_mask is not None:
            encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]
            encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)
        if encoder_attention_mask is not None:
            encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
            encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1])  # [1, N_valid_tokens, C]
            y_seqlens = encoder_attention_mask.sum(dim=1).tolist()  # [B]
        else:
            y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0]
            encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])
        # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
        #     hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w)
        #     hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw)
        #     hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C")
        # blocks
        kv_cache_dict_ret = {}
        for i, block in enumerate(self.blocks):
            block_outputs = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                x=hidden_states,
                y=encoder_hidden_states,
                t=t,
                y_seqlen=y_seqlens,
                latent_shape=(N_t, N_h, N_w),
                num_cond_latents=num_cond_latents,
                return_kv=return_kv,
                kv_cache=kv_cache_dict.get(i, None),
                skip_crs_attn=skip_crs_attn,
            )
            if return_kv:
                hidden_states, kv_cache = block_outputs
                if offload_kv_cache:
                    kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())
                else:
                    kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())
            else:
                hidden_states = block_outputs
        hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w))  # [B, N, C=T_p*H_p*W_p*C_out]
        # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
        #     hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw)
        hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w)  # [B, C_out, T, H, W]
        # cast to float32 for better accuracy
        hidden_states = hidden_states.to(torch.float32)
        if return_kv:
            return hidden_states, kv_cache_dict_ret
        else:
            return hidden_states

    def unpatchify(self, x, N_t, N_h, N_w):
        """
        Args:
            x (torch.Tensor): of shape [B, N, C]
        Return:
            x (torch.Tensor): of shape [B, C_out, T, H, W]
        """
        T_p, H_p, W_p = self.patch_size
        x = rearrange(
            x,
            "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
            N_t=N_t,
            N_h=N_h,
            N_w=N_w,
            T_p=T_p,
            H_p=H_p,
            W_p=W_p,
            C_out=self.out_channels,
        )
        return x

    @staticmethod
    def state_dict_converter():
        return LongCatVideoTransformer3DModelDictConverter()


class LongCatVideoTransformer3DModelDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict

    def from_civitai(self, state_dict):
        return state_dict