Add Diffusers pipeline support

#8
by sam-motamed - opened
diffusers/cogvideox_transformer3d.py ADDED
@@ -0,0 +1,845 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import glob
17
+ import json
18
+ import os
19
+ from typing import Any, Dict, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.attention import Attention, FeedForward
25
+ from diffusers.models.attention_processor import (
26
+ AttentionProcessor, CogVideoXAttnProcessor2_0,
27
+ FusedCogVideoXAttnProcessor2_0)
28
+ from diffusers.models.embeddings import (CogVideoXPatchEmbed,
29
+ TimestepEmbedding, Timesteps,
30
+ get_3d_sincos_pos_embed)
31
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
34
+ from diffusers.utils import is_torch_version, logging
35
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
36
+ from torch import nn
37
+
38
+ from dist_utils import (get_sequence_parallel_rank,
39
+ get_sequence_parallel_world_size,
40
+ get_sp_group,
41
+ xFuserLongContextAttention)
42
+ from dist_utils import CogVideoXMultiGPUsAttnProcessor2_0
43
+
44
+
45
# Module-level logger, following the diffusers convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
46
+
47
+
48
class CogVideoXPatchEmbed(nn.Module):
    """Patch-embedding layer for CogVideoX video latents plus text embeddings.

    NOTE(review): this is a local override of the class imported from
    `diffusers.models.embeddings` above — it additionally resizes the
    precomputed 3D sincos positional table to the runtime latent resolution
    via trilinear interpolation.

    Video latents are turned into patch tokens (`nn.Conv2d` for CogVideoX 1.0
    when `patch_size_t is None`, `nn.Linear` over 3D patches for 1.5), text
    embeddings are projected to the same width, and the two are concatenated
    as `[text, video]` along the sequence axis.
    """

    def __init__(
        self,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_positional_embeddings: bool = True,
        use_learned_positional_embeddings: bool = True,
    ) -> None:
        super().__init__()

        # Token-grid dimensions implied by the reference sample size.
        post_patch_height = sample_height // patch_size
        post_patch_width = sample_width // patch_size
        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1

        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
        self.post_patch_height = post_patch_height
        self.post_patch_width = post_patch_width
        self.post_time_compression_frames = post_time_compression_frames
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.embed_dim = embed_dim
        self.sample_height = sample_height
        self.sample_width = sample_width
        self.sample_frames = sample_frames
        self.temporal_compression_ratio = temporal_compression_ratio
        self.max_text_seq_length = max_text_seq_length
        self.spatial_interpolation_scale = spatial_interpolation_scale
        self.temporal_interpolation_scale = temporal_interpolation_scale
        self.use_positional_embeddings = use_positional_embeddings
        self.use_learned_positional_embeddings = use_learned_positional_embeddings

        if patch_size_t is None:
            # CogVideoX 1.0 checkpoints: spatial-only patching via Conv2d.
            self.proj = nn.Conv2d(
                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
            )
        else:
            # CogVideoX 1.5 checkpoints: joint spatio-temporal patching via Linear.
            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)

        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

        if use_positional_embeddings or use_learned_positional_embeddings:
            # A learned table must be saved with the checkpoint; a pure sincos
            # table can always be rebuilt, so it is kept non-persistent.
            persistent = use_learned_positional_embeddings
            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
        """Build the joint [text zeros | 3D sincos video] positional table."""
        grid_h = sample_height // self.patch_size
        grid_w = sample_width // self.patch_size
        grid_t = (sample_frames - 1) // self.temporal_compression_ratio + 1
        num_patches = grid_h * grid_w * grid_t

        sincos = get_3d_sincos_pos_embed(
            self.embed_dim,
            (grid_w, grid_h),
            grid_t,
            self.spatial_interpolation_scale,
            self.temporal_interpolation_scale,
        )
        sincos = torch.from_numpy(sincos).flatten(0, 1)

        # Text positions are left at zero; only the video span is filled in.
        joint = torch.zeros(1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False)
        joint.data[:, self.max_text_seq_length :].copy_(sincos)
        return joint

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
        """
        text_embeds = self.text_proj(text_embeds)
        text_seq_length = text_embeds.shape[1]
        batch_size, num_frames, channels, height, width = image_embeds.shape

        if self.patch_size_t is None:
            # Fold frames into the batch so Conv2d patches each frame independently.
            image_embeds = image_embeds.reshape(-1, channels, height, width)
            image_embeds = self.proj(image_embeds)
            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [B, F, H*W, C]
            image_embeds = image_embeds.flatten(1, 2)  # [B, F*H*W, C]
        else:
            p = self.patch_size
            p_t = self.patch_size_t

            # b, f, c, h, w -> b, f, h, w, c, then carve out (p_t, p, p) patches.
            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
            image_embeds = image_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
            )
            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
            image_embeds = self.proj(image_embeds)

        # [batch, text_seq + F*H*W patches, channels]
        embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()

        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            seq_length = height * width * num_frames // (self.patch_size**2)
            pos_embeds = self.pos_embedding
            emb_dim = embeds.shape[-1]

            # Resize the precomputed video part of the table to the runtime
            # spatial grid (time size is kept at the precomputed value).
            video_pos = pos_embeds[:, text_seq_length:].view(
                1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_dim
            )
            video_pos = video_pos.permute([0, 4, 1, 2, 3])
            video_pos = F.interpolate(
                video_pos,
                size=[self.post_time_compression_frames, height // self.patch_size, width // self.patch_size],
                mode='trilinear',
                align_corners=False,
            )
            video_pos = video_pos.permute([0, 2, 3, 4, 1]).view(1, -1, emb_dim)

            pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], video_pos], dim=1)
            pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
            embeds = embeds + pos_embeds

        return embeds
176
+
177
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.

    Text and video tokens are processed jointly: each sub-block applies
    AdaLN-zero modulation conditioned on the timestep embedding, then a
    gated residual update for both token streams.

    Parameters:
        dim (`int`): Number of channels in the input and output.
        num_attention_heads (`int`): Number of heads for multi-head attention.
        attention_head_dim (`int`): Number of channels per attention head.
        time_embed_dim (`int`): Number of channels in the timestep embedding.
        dropout (`float`, defaults to `0.0`): Dropout probability.
        activation_fn (`str`, defaults to `"gelu-approximate"`): Feed-forward activation.
        attention_bias (`bool`, defaults to `False`): Use bias in attention projections.
        qk_norm (`bool`, defaults to `True`): Apply LayerNorm after query/key projections.
        norm_elementwise_affine (`bool`, defaults to `True`): Learnable affine in norms.
        norm_eps (`float`, defaults to `1e-5`): Epsilon for normalization layers.
        final_dropout (`bool`, defaults to `True`): Dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*): Custom feed-forward hidden dim; `4 * dim` when `None`.
        ff_bias (`bool`, defaults to `True`): Use bias in feed-forward layers.
        attention_out_bias (`bool`, defaults to `True`): Use bias in attention output projection.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self attention over the joint [text, video] sequence.
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed forward, shared between text and video tokens.
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        # Attention sub-block: modulate, attend jointly, apply gated residual.
        norm_hidden, norm_encoder, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )
        attn_hidden, attn_encoder = self.attn1(
            hidden_states=norm_hidden,
            encoder_hidden_states=norm_encoder,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + gate_msa * attn_hidden
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder

        # Feed-forward sub-block: one FF pass over the concatenated sequence,
        # then the output is split back into text and video parts.
        norm_hidden, norm_encoder, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )
        ff_output = self.ff(torch.cat([norm_encoder, norm_hidden], dim=1))
        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states
295
+
296
+
297
+ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
298
+ """
299
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
300
+
301
+ Parameters:
302
+ num_attention_heads (`int`, defaults to `30`):
303
+ The number of heads to use for multi-head attention.
304
+ attention_head_dim (`int`, defaults to `64`):
305
+ The number of channels in each head.
306
+ in_channels (`int`, defaults to `16`):
307
+ The number of channels in the input.
308
+ out_channels (`int`, *optional*, defaults to `16`):
309
+ The number of channels in the output.
310
+ flip_sin_to_cos (`bool`, defaults to `True`):
311
+ Whether to flip the sin to cos in the time embedding.
312
+ time_embed_dim (`int`, defaults to `512`):
313
+ Output dimension of timestep embeddings.
314
+ text_embed_dim (`int`, defaults to `4096`):
315
+ Input dimension of text embeddings from the text encoder.
316
+ num_layers (`int`, defaults to `30`):
317
+ The number of layers of Transformer blocks to use.
318
+ dropout (`float`, defaults to `0.0`):
319
+ The dropout probability to use.
320
+ attention_bias (`bool`, defaults to `True`):
321
+ Whether or not to use bias in the attention projection layers.
322
+ sample_width (`int`, defaults to `90`):
323
+ The width of the input latents.
324
+ sample_height (`int`, defaults to `60`):
325
+ The height of the input latents.
326
+ sample_frames (`int`, defaults to `49`):
327
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
328
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
329
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
330
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
331
+ patch_size (`int`, defaults to `2`):
332
+ The size of the patches to use in the patch embedding layer.
333
+ temporal_compression_ratio (`int`, defaults to `4`):
334
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
335
+ max_text_seq_length (`int`, defaults to `226`):
336
+ The maximum sequence length of the input text embeddings.
337
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
338
+ Activation function to use in feed-forward.
339
+ timestep_activation_fn (`str`, defaults to `"silu"`):
340
+ Activation function to use when generating the timestep embeddings.
341
+ norm_elementwise_affine (`bool`, defaults to `True`):
342
+ Whether or not to use elementwise affine in normalization layers.
343
+ norm_eps (`float`, defaults to `1e-5`):
344
+ The epsilon value to use in normalization layers.
345
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
346
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
347
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
348
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
349
+ """
350
+
351
+ _supports_gradient_checkpointing = True
352
+
353
+ @register_to_config
354
+ def __init__(
355
+ self,
356
+ num_attention_heads: int = 30,
357
+ attention_head_dim: int = 64,
358
+ in_channels: int = 16,
359
+ out_channels: Optional[int] = 16,
360
+ flip_sin_to_cos: bool = True,
361
+ freq_shift: int = 0,
362
+ time_embed_dim: int = 512,
363
+ text_embed_dim: int = 4096,
364
+ num_layers: int = 30,
365
+ dropout: float = 0.0,
366
+ attention_bias: bool = True,
367
+ sample_width: int = 90,
368
+ sample_height: int = 60,
369
+ sample_frames: int = 49,
370
+ patch_size: int = 2,
371
+ patch_size_t: Optional[int] = None,
372
+ temporal_compression_ratio: int = 4,
373
+ max_text_seq_length: int = 226,
374
+ activation_fn: str = "gelu-approximate",
375
+ timestep_activation_fn: str = "silu",
376
+ norm_elementwise_affine: bool = True,
377
+ norm_eps: float = 1e-5,
378
+ spatial_interpolation_scale: float = 1.875,
379
+ temporal_interpolation_scale: float = 1.0,
380
+ use_rotary_positional_embeddings: bool = False,
381
+ use_learned_positional_embeddings: bool = False,
382
+ patch_bias: bool = True,
383
+ add_noise_in_inpaint_model: bool = False,
384
+ ):
385
+ super().__init__()
386
+ inner_dim = num_attention_heads * attention_head_dim
387
+ self.patch_size_t = patch_size_t
388
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
389
+ raise ValueError(
390
+ "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
391
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
392
+ "issue at https://github.com/huggingface/diffusers/issues."
393
+ )
394
+
395
+ # 1. Patch embedding
396
+ self.patch_embed = CogVideoXPatchEmbed(
397
+ patch_size=patch_size,
398
+ patch_size_t=patch_size_t,
399
+ in_channels=in_channels,
400
+ embed_dim=inner_dim,
401
+ text_embed_dim=text_embed_dim,
402
+ bias=patch_bias,
403
+ sample_width=sample_width,
404
+ sample_height=sample_height,
405
+ sample_frames=sample_frames,
406
+ temporal_compression_ratio=temporal_compression_ratio,
407
+ max_text_seq_length=max_text_seq_length,
408
+ spatial_interpolation_scale=spatial_interpolation_scale,
409
+ temporal_interpolation_scale=temporal_interpolation_scale,
410
+ use_positional_embeddings=not use_rotary_positional_embeddings,
411
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
412
+ )
413
+ self.embedding_dropout = nn.Dropout(dropout)
414
+
415
+ # 2. Time embeddings
416
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
417
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
418
+
419
+ # 3. Define spatio-temporal transformers blocks
420
+ self.transformer_blocks = nn.ModuleList(
421
+ [
422
+ CogVideoXBlock(
423
+ dim=inner_dim,
424
+ num_attention_heads=num_attention_heads,
425
+ attention_head_dim=attention_head_dim,
426
+ time_embed_dim=time_embed_dim,
427
+ dropout=dropout,
428
+ activation_fn=activation_fn,
429
+ attention_bias=attention_bias,
430
+ norm_elementwise_affine=norm_elementwise_affine,
431
+ norm_eps=norm_eps,
432
+ )
433
+ for _ in range(num_layers)
434
+ ]
435
+ )
436
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
437
+
438
+ # 4. Output blocks
439
+ self.norm_out = AdaLayerNorm(
440
+ embedding_dim=time_embed_dim,
441
+ output_dim=2 * inner_dim,
442
+ norm_elementwise_affine=norm_elementwise_affine,
443
+ norm_eps=norm_eps,
444
+ chunk_dim=1,
445
+ )
446
+
447
+ if patch_size_t is None:
448
+ # For CogVideox 1.0
449
+ output_dim = patch_size * patch_size * out_channels
450
+ else:
451
+ # For CogVideoX 1.5
452
+ output_dim = patch_size * patch_size * patch_size_t * out_channels
453
+
454
+ self.proj_out = nn.Linear(inner_dim, output_dim)
455
+
456
+ self.gradient_checkpointing = False
457
+ self.sp_world_size = 1
458
+ self.sp_world_rank = 0
459
+
460
+ def _set_gradient_checkpointing(self, module, value=False):
461
+ self.gradient_checkpointing = value
462
+
463
+ def enable_multi_gpus_inference(self,):
464
+ self.sp_world_size = get_sequence_parallel_world_size()
465
+ self.sp_world_rank = get_sequence_parallel_rank()
466
+ self.set_attn_processor(CogVideoXMultiGPUsAttnProcessor2_0())
467
+
468
+ @property
469
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
470
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
471
+ r"""
472
+ Returns:
473
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
474
+ indexed by its weight name.
475
+ """
476
+ # set recursively
477
+ processors = {}
478
+
479
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
480
+ if hasattr(module, "get_processor"):
481
+ processors[f"{name}.processor"] = module.get_processor()
482
+
483
+ for sub_name, child in module.named_children():
484
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
485
+
486
+ return processors
487
+
488
+ for name, module in self.named_children():
489
+ fn_recursive_add_processors(name, module, processors)
490
+
491
+ return processors
492
+
493
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
494
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
495
+ r"""
496
+ Sets the attention processor to use to compute attention.
497
+
498
+ Parameters:
499
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
500
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
501
+ for **all** `Attention` layers.
502
+
503
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
504
+ processor. This is strongly recommended when setting trainable attention processors.
505
+
506
+ """
507
+ count = len(self.attn_processors.keys())
508
+
509
+ if isinstance(processor, dict) and len(processor) != count:
510
+ raise ValueError(
511
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
512
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
513
+ )
514
+
515
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
516
+ if hasattr(module, "set_processor"):
517
+ if not isinstance(processor, dict):
518
+ module.set_processor(processor)
519
+ else:
520
+ module.set_processor(processor.pop(f"{name}.processor"))
521
+
522
+ for sub_name, child in module.named_children():
523
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
524
+
525
+ for name, module in self.named_children():
526
+ fn_recursive_attn_processor(name, module, processor)
527
+
528
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
529
+ def fuse_qkv_projections(self):
530
+ """
531
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
532
+ are fused. For cross-attention modules, key and value projection matrices are fused.
533
+
534
+ <Tip warning={true}>
535
+
536
+ This API is πŸ§ͺ experimental.
537
+
538
+ </Tip>
539
+ """
540
+ self.original_attn_processors = None
541
+
542
+ for _, attn_processor in self.attn_processors.items():
543
+ if "Added" in str(attn_processor.__class__.__name__):
544
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
545
+
546
+ self.original_attn_processors = self.attn_processors
547
+
548
+ for module in self.modules():
549
+ if isinstance(module, Attention):
550
+ module.fuse_projections(fuse=True)
551
+
552
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
553
+
554
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
555
+ def unfuse_qkv_projections(self):
556
+ """Disables the fused QKV projection if enabled.
557
+
558
+ <Tip warning={true}>
559
+
560
+ This API is πŸ§ͺ experimental.
561
+
562
+ </Tip>
563
+
564
+ """
565
+ if self.original_attn_processors is not None:
566
+ self.set_attn_processor(self.original_attn_processors)
567
+
568
+ def forward(
569
+ self,
570
+ hidden_states: torch.Tensor,
571
+ encoder_hidden_states: torch.Tensor,
572
+ timestep: Union[int, float, torch.LongTensor],
573
+ timestep_cond: Optional[torch.Tensor] = None,
574
+ inpaint_latents: Optional[torch.Tensor] = None,
575
+ control_latents: Optional[torch.Tensor] = None,
576
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
577
+ return_dict: bool = True,
578
+ ):
579
+ batch_size, num_frames, channels, height, width = hidden_states.shape
580
+ if num_frames == 1 and self.patch_size_t is not None:
581
+ hidden_states = torch.cat([hidden_states, torch.zeros_like(hidden_states)], dim=1)
582
+ if inpaint_latents is not None:
583
+ inpaint_latents = torch.concat([inpaint_latents, torch.zeros_like(inpaint_latents)], dim=1)
584
+ if control_latents is not None:
585
+ control_latents = torch.concat([control_latents, torch.zeros_like(control_latents)], dim=1)
586
+ local_num_frames = num_frames + 1
587
+ else:
588
+ local_num_frames = num_frames
589
+
590
+ # 1. Time embedding
591
+ timesteps = timestep
592
+ t_emb = self.time_proj(timesteps)
593
+
594
+ # timesteps does not contain any weights and will always return f32 tensors
595
+ # but time_embedding might actually be running in fp16. so we need to cast here.
596
+ # there might be better ways to encapsulate this.
597
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
598
+ emb = self.time_embedding(t_emb, timestep_cond)
599
+
600
+ # 2. Patch embedding
601
+ if inpaint_latents is not None:
602
+ hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
603
+ if control_latents is not None:
604
+ hidden_states = torch.concat([hidden_states, control_latents], 2)
605
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
606
+ hidden_states = self.embedding_dropout(hidden_states)
607
+
608
+ text_seq_length = encoder_hidden_states.shape[1]
609
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
610
+ hidden_states = hidden_states[:, text_seq_length:]
611
+
612
+ # Context Parallel
613
+ if self.sp_world_size > 1:
614
+ hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
615
+ if image_rotary_emb is not None:
616
+ image_rotary_emb = (
617
+ torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
618
+ torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank]
619
+ )
620
+
621
+ # 3. Transformer blocks
622
+ for i, block in enumerate(self.transformer_blocks):
623
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
624
+
625
+ def create_custom_forward(module):
626
+ def custom_forward(*inputs):
627
+ return module(*inputs)
628
+
629
+ return custom_forward
630
+
631
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
632
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
633
+ create_custom_forward(block),
634
+ hidden_states,
635
+ encoder_hidden_states,
636
+ emb,
637
+ image_rotary_emb,
638
+ **ckpt_kwargs,
639
+ )
640
+ else:
641
+ hidden_states, encoder_hidden_states = block(
642
+ hidden_states=hidden_states,
643
+ encoder_hidden_states=encoder_hidden_states,
644
+ temb=emb,
645
+ image_rotary_emb=image_rotary_emb,
646
+ )
647
+
648
+ if not self.config.use_rotary_positional_embeddings:
649
+ # CogVideoX-2B
650
+ hidden_states = self.norm_final(hidden_states)
651
+ else:
652
+ # CogVideoX-5B
653
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
654
+ hidden_states = self.norm_final(hidden_states)
655
+ hidden_states = hidden_states[:, text_seq_length:]
656
+
657
+ # 4. Final block
658
+ hidden_states = self.norm_out(hidden_states, temb=emb)
659
+ hidden_states = self.proj_out(hidden_states)
660
+
661
+ if self.sp_world_size > 1:
662
+ hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
663
+
664
+ # 5. Unpatchify
665
+ p = self.config.patch_size
666
+ p_t = self.config.patch_size_t
667
+
668
+ if p_t is None:
669
+ output = hidden_states.reshape(batch_size, local_num_frames, height // p, width // p, -1, p, p)
670
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
671
+ else:
672
+ output = hidden_states.reshape(
673
+ batch_size, (local_num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
674
+ )
675
+ output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
676
+
677
+ if num_frames == 1:
678
+ output = output[:, :num_frames, :]
679
+
680
+ if not return_dict:
681
+ return (output,)
682
+ return Transformer2DModelOutput(sample=output)
683
+
684
+ @classmethod
685
+ def from_pretrained(
686
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
687
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16, use_vae_mask=False, stack_mask=False,
688
+ ):
689
+ if subfolder is not None:
690
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
691
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
692
+
693
+ config_file = os.path.join(pretrained_model_path, 'config.json')
694
+ if not os.path.isfile(config_file):
695
+ raise RuntimeError(f"{config_file} does not exist")
696
+ with open(config_file, "r") as f:
697
+ config = json.load(f)
698
+
699
+ if use_vae_mask:
700
+ print('[DEBUG] use vae to encode mask')
701
+ config['in_channels'] = 48
702
+ elif stack_mask:
703
+ print('[DEBUG] use stacking mask')
704
+ config['in_channels'] = 36
705
+
706
+ from diffusers.utils import WEIGHTS_NAME
707
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
708
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
709
+
710
+ if "dict_mapping" in transformer_additional_kwargs.keys():
711
+ for key in transformer_additional_kwargs["dict_mapping"]:
712
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
713
+
714
+ if low_cpu_mem_usage:
715
+ try:
716
+ import re
717
+
718
+ from diffusers.models.modeling_utils import \
719
+ load_model_dict_into_meta
720
+ from diffusers.utils import is_accelerate_available
721
+ if is_accelerate_available():
722
+ import accelerate
723
+
724
+ # Instantiate model with empty weights
725
+ with accelerate.init_empty_weights():
726
+ model = cls.from_config(config, **transformer_additional_kwargs)
727
+
728
+ param_device = "cpu"
729
+ if os.path.exists(model_file):
730
+ state_dict = torch.load(model_file, map_location="cpu")
731
+ elif os.path.exists(model_file_safetensors):
732
+ from safetensors.torch import load_file, safe_open
733
+ state_dict = load_file(model_file_safetensors)
734
+ else:
735
+ from safetensors.torch import load_file, safe_open
736
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
737
+ state_dict = {}
738
+ for _model_file_safetensors in model_files_safetensors:
739
+ _state_dict = load_file(_model_file_safetensors)
740
+ for key in _state_dict:
741
+ state_dict[key] = _state_dict[key]
742
+ model._convert_deprecated_attention_blocks(state_dict)
743
+ # move the params from meta device to cpu
744
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
745
+ if len(missing_keys) > 0:
746
+ raise ValueError(
747
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
748
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
749
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
750
+ " those weights or else make sure your checkpoint file is correct."
751
+ )
752
+
753
+ unexpected_keys = load_model_dict_into_meta(
754
+ model,
755
+ state_dict,
756
+ device=param_device,
757
+ dtype=torch_dtype,
758
+ model_name_or_path=pretrained_model_path,
759
+ )
760
+
761
+ if cls._keys_to_ignore_on_load_unexpected is not None:
762
+ for pat in cls._keys_to_ignore_on_load_unexpected:
763
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
764
+
765
+ if len(unexpected_keys) > 0:
766
+ print(
767
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
768
+ )
769
+ return model
770
+ except Exception as e:
771
+ print(
772
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
773
+ )
774
+
775
+ model = cls.from_config(config, **transformer_additional_kwargs)
776
+ if os.path.exists(model_file):
777
+ state_dict = torch.load(model_file, map_location="cpu")
778
+ elif os.path.exists(model_file_safetensors):
779
+ from safetensors.torch import load_file, safe_open
780
+ state_dict = load_file(model_file_safetensors)
781
+ else:
782
+ from safetensors.torch import load_file, safe_open
783
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
784
+ state_dict = {}
785
+ for _model_file_safetensors in model_files_safetensors:
786
+ _state_dict = load_file(_model_file_safetensors)
787
+ for key in _state_dict:
788
+ state_dict[key] = _state_dict[key]
789
+
790
+ if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
791
+ new_shape = model.state_dict()['patch_embed.proj.weight'].size()
792
+ if len(new_shape) == 5:
793
+ state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
794
+ state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
795
+ elif len(new_shape) == 2:
796
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
797
+ if use_vae_mask:
798
+ print('[DEBUG] patch_embed.proj.weight size does not match due to vae-encoded mask')
799
+ latent_ch = 16
800
+ feat_scale = 8
801
+ feat_dim = int(latent_ch * feat_scale)
802
+ old_total_dim = state_dict['patch_embed.proj.weight'].size(1)
803
+ new_total_dim = model.state_dict()['patch_embed.proj.weight'].size(1)
804
+ model.state_dict()['patch_embed.proj.weight'][:, :feat_dim] = state_dict['patch_embed.proj.weight'][:, :feat_dim]
805
+ model.state_dict()['patch_embed.proj.weight'][:, -feat_dim:] = state_dict['patch_embed.proj.weight'][:, -feat_dim:]
806
+ for i in range(feat_dim, new_total_dim - feat_dim, feat_scale):
807
+ model.state_dict()['patch_embed.proj.weight'][:, i:i+feat_scale] = state_dict['patch_embed.proj.weight'][:, feat_dim:-feat_dim]
808
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
809
+ else:
810
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1]] = state_dict['patch_embed.proj.weight']
811
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:] = 0
812
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
813
+ else:
814
+ model.state_dict()['patch_embed.proj.weight'][:, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1]]
815
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
816
+ else:
817
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
818
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
819
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
820
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
821
+ else:
822
+ model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
823
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
824
+
825
+ tmp_state_dict = {}
826
+ for key in state_dict:
827
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
828
+ tmp_state_dict[key] = state_dict[key]
829
+ else:
830
+ print(key, "Size don't match, skip")
831
+
832
+ state_dict = tmp_state_dict
833
+
834
+ m, u = model.load_state_dict(state_dict, strict=False)
835
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
836
+ print(m)
837
+
838
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
839
+ print(f"### All Parameters: {sum(params) / 1e6} M")
840
+
841
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
842
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
843
+
844
+ model = model.to(torch_dtype)
845
+ return model
diffusers/cogvideox_vae.py ADDED
@@ -0,0 +1,1675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import json
23
+ import os
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
27
+ from diffusers.utils import logging
28
+ from diffusers.utils.accelerate_utils import apply_forward_hook
29
+ from diffusers.models.activations import get_activation
30
+ from diffusers.models.downsampling import CogVideoXDownsample3D
31
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.upsampling import CogVideoXUpsample3D
34
+ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
class CogVideoXSafeConv3d(nn.Conv3d):
    r"""
    A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Estimated fp16 footprint of the input, in GiB.
        numel = input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]
        memory_count = numel * 2 / 1024**3

        # Below the ~2 GiB threshold (suitable for CuDNN) run in one shot.
        if memory_count <= 2:
            return super().forward(input)

        kernel_size = self.kernel_size[0]
        part_num = int(memory_count / 2) + 1
        chunks = list(torch.chunk(input, part_num, dim=2))

        if kernel_size > 1:
            # Re-attach the last (kernel_size - 1) frames of the previous chunk
            # so the temporal convolution sees the same context it would on the
            # unsplit tensor.
            overlapped = [chunks[0]]
            for prev, cur in zip(chunks, chunks[1:]):
                overlapped.append(torch.cat((prev[:, :, -kernel_size + 1 :], cur), dim=2))
            chunks = overlapped

        # Explicit super() form: the zero-arg form is unavailable inside a
        # comprehension scope.
        return torch.cat(
            [super(CogVideoXSafeConv3d, self).forward(chunk) for chunk in chunks], dim=2
        )
69
+
70
+
71
class CogVideoXCausalConv3d(nn.Module):
    r"""A 3D convolution that is causal along time: all temporal padding goes on
    the past side, so an output frame never depends on future frames.

    Args:
        in_channels (`int`): Number of channels in the input tensor.
        out_channels (`int`): Number of output channels produced by the convolution.
        kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
        stride (`int`, defaults to `1`): Stride of the convolution (temporal axis).
        dilation (`int`, defaults to `1`): Dilation rate of the convolution (temporal axis).
        pad_mode (`str`, defaults to `"constant"`): Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: int = 1,
        dilation: int = 1,
        pad_mode: str = "constant",
    ):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        t_k, h_k, w_k = kernel_size

        # Spatial padding is symmetric; temporal padding is entirely "left"
        # (past). F.pad takes (w_lo, w_hi, h_lo, h_hi, t_lo, t_hi).
        self.pad_mode = pad_mode
        self.height_pad = (h_k - 1) // 2
        self.width_pad = (w_k - 1) // 2
        self.time_pad = t_k - 1
        self.time_causal_padding = (
            self.width_pad, self.width_pad, self.height_pad, self.height_pad, self.time_pad, 0
        )

        self.temporal_dim = 2
        self.time_kernel_size = t_k

        if not isinstance(stride, tuple):
            stride = (stride, 1, 1)
        self.conv = CogVideoXSafeConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=(dilation, 1, 1),
        )

    def fake_context_parallel_forward(
        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Prepend temporal context: cached trailing frames from the previous
        chunk, or a repeat of frame 0 on the very first chunk."""
        if self.pad_mode == "replicate":
            return F.pad(inputs, self.time_causal_padding, mode="replicate")
        if self.time_kernel_size <= 1:
            return inputs
        if conv_cache is not None:
            prefix = [conv_cache]
        else:
            prefix = [inputs[:, :, :1]] * (self.time_kernel_size - 1)
        return torch.cat(prefix + [inputs], dim=2)

    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
        inputs = self.fake_context_parallel_forward(inputs, conv_cache)

        if self.pad_mode == "replicate":
            conv_cache = None
        else:
            # Remember the trailing frames as context for the next chunk, then
            # zero-pad spatially only (temporal context is already attached).
            conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
            inputs = F.pad(
                inputs,
                (self.width_pad, self.width_pad, self.height_pad, self.height_pad),
                mode="constant",
                value=0,
            )

        return self.conv(inputs), conv_cache
148
+
149
+
150
class CogVideoXSpatialNorm3D(nn.Module):
    r"""
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002, adapted to
    3D-video like data: group-normalized features are modulated by a scale and a bias predicted
    from the quantized latent `zq`.

    CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.

    Args:
        f_channels (`int`):
            The number of channels for input to group normalization layer, and output of the spatial norm layer.
        zq_channels (`int`):
            The number of channels for the quantized vector as described in the paper.
        groups (`int`):
            Number of groups to separate the channels into for group normalization.
    """

    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
        groups: int = 32,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
        self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
        self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)

    def forward(
        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
    ) -> torch.Tensor:
        conv_cache = conv_cache or {}
        new_conv_cache = {}

        num_frames = f.shape[2]
        if num_frames > 1 and num_frames % 2 == 1:
            # Odd frame count: resize the first frame separately so the
            # remaining (even-length) frames interpolate cleanly.
            zq_first = F.interpolate(zq[:, :, :1], size=f[:, :, :1].shape[-3:])
            zq_rest = F.interpolate(zq[:, :, 1:], size=f[:, :, 1:].shape[-3:])
            zq = torch.cat([zq_first, zq_rest], dim=2)
        else:
            zq = F.interpolate(zq, size=f.shape[-3:])

        scale, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
        shift, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))

        return self.norm_layer(f) * scale + shift, new_conv_cache
199
+
200
+
201
class CogVideoXUpsample3D(nn.Module):
    r"""
    A 3D upsample layer used in CogVideoX by Tsinghua University & ZhipuAI.

    NOTE(review): this local definition shadows the `CogVideoXUpsample3D` imported above from
    `diffusers.models.upsampling`; it adds the `auto_split_process` / `first_frame_flag` runtime
    toggles for chunked decoding.

    Args:
        in_channels (`int`):
            Number of channels in the input image.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `1`):
            Stride of the convolution.
        padding (`int`, defaults to `1`):
            Padding added to all four sides of the input.
        compress_time (`bool`, defaults to `False`):
            Whether or not to compress the time dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        compress_time: bool = False,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.compress_time = compress_time

        # Runtime toggles used by chunked decoding; both may be flipped by the
        # caller after construction.
        self.auto_split_process = True
        self.first_frame_flag = False

    def _upsample_with_time(self, inputs: torch.Tensor) -> torch.Tensor:
        """2x nearest upsample including the temporal axis (compress_time path)."""
        if self.auto_split_process:
            num_frames = inputs.shape[2]
            if num_frames > 1 and num_frames % 2 == 1:
                # Odd frame count: the first frame stays a single temporal slot
                # (spatial-only 2x); the rest are upsampled in time and space.
                head, tail = inputs[:, :, 0], inputs[:, :, 1:]
                head = F.interpolate(head, scale_factor=2.0)[:, :, None, :, :]
                tail = F.interpolate(tail, scale_factor=2.0)
                return torch.cat([head, tail], dim=2)
            if num_frames > 1:
                return F.interpolate(inputs, scale_factor=2.0)
            # Single frame: spatial-only 2x.
            frame = F.interpolate(inputs.squeeze(2), scale_factor=2.0)
            return frame[:, :, None, :, :]
        if self.first_frame_flag:
            frame = F.interpolate(inputs.squeeze(2), scale_factor=2.0)
            return frame[:, :, None, :, :]
        return F.interpolate(inputs, scale_factor=2.0)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if self.compress_time:
            inputs = self._upsample_with_time(inputs)
        else:
            # Spatial-only 2x upsample, applied frame by frame.
            b, c, t, h, w = inputs.shape
            flat = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
            flat = F.interpolate(flat, scale_factor=2.0)
            inputs = flat.reshape(b, t, c, *flat.shape[2:]).permute(0, 2, 1, 3, 4)

        # Apply the 2D convolution to every frame independently.
        b, c, t, h, w = inputs.shape
        flat = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        flat = self.conv(flat)
        return flat.reshape(b, t, *flat.shape[1:]).permute(0, 2, 1, 3, 4)
274
+
275
+
276
class CogVideoXResnetBlock3D(nn.Module):
    r"""
    A 3D ResNet block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        non_linearity (`str`, defaults to `"swish"`):
            Activation function to use.
        conv_shortcut (bool, defaults to `False`):
            Whether or not to use a convolution shortcut.
        spatial_norm_dim (`int`, *optional*):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        conv_shortcut: bool = False,
        spatial_norm_dim: Optional[int] = None,
        pad_mode: str = "first",
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nonlinearity = get_activation(non_linearity)
        self.use_conv_shortcut = conv_shortcut
        self.spatial_norm_dim = spatial_norm_dim

        # Plain group norm when no conditioning latent is expected; spatially
        # conditioned norm (modulated by `zq`) when `spatial_norm_dim` is set.
        if spatial_norm_dim is None:
            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
            self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
        else:
            self.norm1 = CogVideoXSpatialNorm3D(
                f_channels=in_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )
            self.norm2 = CogVideoXSpatialNorm3D(
                f_channels=out_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )

        self.conv1 = CogVideoXCausalConv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        # Time-embedding projection only exists when temb conditioning is used.
        if temb_channels > 0:
            self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)

        self.dropout = nn.Dropout(dropout)
        self.conv2 = CogVideoXCausalConv3d(
            in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        # When channel counts differ the residual path needs a projection:
        # either a causal 3x3 conv or a cheap pointwise (1x1x1) conv.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = CogVideoXCausalConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
                )
            else:
                self.conv_shortcut = CogVideoXSafeConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(
        self,
        inputs: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run norm -> act -> conv twice with optional temb/zq conditioning.

        Returns `(hidden_states, new_conv_cache)`: the cache carries each
        causal submodule's trailing temporal context for chunked decoding.
        """
        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states = inputs

        if zq is not None:
            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))

        if temb is not None:
            # Broadcast the projected time embedding over (T, H, W).
            hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]

        if zq is not None:
            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
        else:
            hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))

        # Residual connection (with projection if channel counts differ).
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
                    inputs, conv_cache=conv_cache.get("conv_shortcut")
                )
            else:
                inputs = self.conv_shortcut(inputs)

        hidden_states = hidden_states + inputs
        return hidden_states, new_conv_cache
405
+
406
+
407
class CogVideoXDownBlock3D(nn.Module):
    r"""
    A downsampling block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        add_downsample (`bool`, defaults to `True`):
            Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
        compress_time (`bool`, defaults to `False`):
            Whether or not to downsample across temporal dimension.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    # Allows diffusers' enable_gradient_checkpointing() to flip the flag below.
    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_padding: int = 0,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        resnets = []
        for i in range(num_layers):
            # Only the first resnet changes the channel count.
            in_channel = in_channels if i == 0 else out_channels
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.downsamplers = None

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    CogVideoXDownsample3D(
                        out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""Forward method of the `CogVideoXDownBlock3D` class.

        Returns `(hidden_states, new_conv_cache)`; the cache maps
        `"resnet_{i}"` to each resnet's temporal context for chunked decoding.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:

                # Wrapper so checkpoint() can re-invoke the resnet with
                # positional arguments during recomputation.
                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    zq,
                    conv_cache.get(conv_cache_key),
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states, new_conv_cache
525
+
526
+
527
class CogVideoXMidBlock3D(nn.Module):
    r"""
    A middle block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        spatial_norm_dim (`int`, *optional*):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    # Allows diffusers' enable_gradient_checkpointing() to flip the flag below.
    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: Optional[int] = None,
        pad_mode: str = "first",
    ):
        super().__init__()

        # All mid-block resnets keep the channel count unchanged.
        resnets = []
        for _ in range(num_layers):
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    spatial_norm_dim=spatial_norm_dim,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
            )
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""Forward method of the `CogVideoXMidBlock3D` class.

        Returns `(hidden_states, new_conv_cache)`; the cache maps
        `"resnet_{i}"` to each resnet's temporal context for chunked decoding.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:

                # Wrapper so checkpoint() can re-invoke the resnet with
                # positional arguments during recomputation.
                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        return hidden_states, new_conv_cache
619
+
620
+
621
class CogVideoXUpBlock3D(nn.Module):
    r"""
    An upsampling block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`):
            Number of output channels.
        temb_channels (`int`):
            Number of time embedding channels.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        spatial_norm_dim (`int`, defaults to `16`):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        add_upsample (`bool`, defaults to `True`):
            Whether or not to use an upsampling layer. If not used, output dimension would be same as input dimension.
        upsample_padding (`int`, defaults to `1`):
            Padding passed to the upsampling convolution.
        compress_time (`bool`, defaults to `False`):
            Whether or not to upsample across the temporal dimension.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    # Consistency with the sibling down/mid blocks, which declare this flag too.
    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: int = 16,
        add_upsample: bool = True,
        upsample_padding: int = 1,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        resnets = []
        for i in range(num_layers):
            # Only the first resnet changes the channel count; later ones keep `out_channels`.
            in_channel = in_channels if i == 0 else out_channels
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    spatial_norm_dim=spatial_norm_dim,
                    pad_mode=pad_mode,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.upsamplers = None

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    CogVideoXUpsample3D(
                        out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""Forward method of the `CogVideoXUpBlock3D` class.

        Returns the transformed hidden states together with the refreshed causal-conv
        cache collected from each resnet (keyed `"resnet_{i}"`).
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    zq,
                    conv_cache.get(conv_cache_key),
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states, new_conv_cache
741
+
742
+
743
class CogVideoXEncoder3D(nn.Module):
    r"""
    The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 16):
            The number of latent channels; the encoder emits `2 * out_channels` (concatenated mean and log-variance).
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXDownBlock3D"` entries):
            The types of down blocks to use. Only `"CogVideoXDownBlock3D"` is supported.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 3):
            The number of resnet layers per block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon used by the resnet normalization layers.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate.
        pad_mode (`str`, *optional*, defaults to `"first"`):
            Padding mode for the causal convolutions.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            Overall temporal downsampling factor; assumed to be a power of two.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        down_block_types: Tuple[str, ...] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # log2 of temporal_compress_times: the first `temporal_compress_level` down
        # blocks also compress the temporal dimension.
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
        self.down_blocks = nn.ModuleList([])

        # down blocks
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            compress_time = i < temporal_compress_level

            if down_block_type == "CogVideoXDownBlock3D":
                down_block = CogVideoXDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    add_downsample=not is_final_block,
                    compress_time=compress_time,
                )
            else:
                raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")

            self.down_blocks.append(down_block)

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=block_out_channels[-1],
            temb_channels=0,
            dropout=dropout,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            pad_mode=pad_mode,
        )

        # NOTE(review): eps is hard-coded to 1e-6 here (not `norm_eps`) — kept as-is to
        # preserve checkpoint parity; confirm if a different `norm_eps` is ever used.
        self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
        self.conv_act = nn.SiLU()
        # 2 * out_channels: concatenated mean and log-variance of the latent distribution.
        self.conv_out = CogVideoXCausalConv3d(
            block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""The forward method of the `CogVideoXEncoder3D` class.

        Returns the encoded moments together with the refreshed causal-conv cache
        (keys: `"conv_in"`, `"down_block_{i}"`, `"mid_block"`, `"conv_out"`).
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Down
            for i, down_block in enumerate(self.down_blocks):
                conv_cache_key = f"down_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block),
                    hidden_states,
                    temb,
                    None,
                    conv_cache.get(conv_cache_key),
                )

            # 2. Mid
            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                temb,
                None,
                conv_cache.get("mid_block"),
            )
        else:
            # 1. Down
            for i, down_block in enumerate(self.down_blocks):
                conv_cache_key = f"down_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = down_block(
                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
                )

            # 2. Mid
            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
            )

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)

        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))

        return hidden_states, new_conv_cache
900
+
901
+
902
class CogVideoXDecoder3D(nn.Module):
    r"""
    The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            The number of input (latent) channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXUpBlock3D"` entries):
            The types of up blocks to use. Only `"CogVideoXUpBlock3D"` is supported.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block (traversed in reverse while decoding).
        layers_per_block (`int`, *optional*, defaults to 3):
            The number of resnet layers per block (each up block uses `layers_per_block + 1`).
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon used by the resnet normalization layers.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate.
        pad_mode (`str`, *optional*, defaults to `"first"`):
            Padding mode for the causal convolutions.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            Overall temporal upsampling factor; assumed to be a power of two.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        reversed_block_out_channels = list(reversed(block_out_channels))

        self.conv_in = CogVideoXCausalConv3d(
            in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
        )

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=reversed_block_out_channels[0],
            temb_channels=0,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            spatial_norm_dim=in_channels,
            pad_mode=pad_mode,
        )

        # up blocks
        self.up_blocks = nn.ModuleList([])

        output_channel = reversed_block_out_channels[0]
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # Mirror of the encoder: the first blocks also expand the temporal dimension.
            compress_time = i < temporal_compress_level

            if up_block_type == "CogVideoXUpBlock3D":
                up_block = CogVideoXUpBlock3D(
                    in_channels=prev_output_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    spatial_norm_dim=in_channels,
                    add_upsample=not is_final_block,
                    compress_time=compress_time,
                    pad_mode=pad_mode,
                )
            else:
                raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")

            self.up_blocks.append(up_block)

        self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""The forward method of the `CogVideoXDecoder3D` class.

        `sample` (the latent) is also passed to the blocks as the spatial-norm
        conditioning (`zq`). Returns the decoded output together with the refreshed
        causal-conv cache (keys: `"conv_in"`, `"mid_block"`, `"up_block_{i}"`,
        `"norm_out"`, `"conv_out"`).
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Mid
            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                temb,
                sample,
                conv_cache.get("mid_block"),
            )

            # 2. Up
            for i, up_block in enumerate(self.up_blocks):
                conv_cache_key = f"up_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block),
                    hidden_states,
                    temb,
                    sample,
                    conv_cache.get(conv_cache_key),
                )
        else:
            # 1. Mid
            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
            )

            # 2. Up
            for i, up_block in enumerate(self.up_blocks):
                conv_cache_key = f"up_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = up_block(
                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
                )

        # 3. Post-process
        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
        )
        hidden_states = self.conv_act(hidden_states)
        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))

        return hidden_states, new_conv_cache
1067
+
1068
+
1069
+ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
1070
+ r"""
1071
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
1072
+ [CogVideoX](https://github.com/THUDM/CogVideo).
1073
+
1074
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
1075
+ for all models (such as downloading or saving).
1076
+
1077
+ Parameters:
1078
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
1079
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
1080
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
1081
+ Tuple of downsample block types.
1082
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
1083
+ Tuple of upsample block types.
1084
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
1085
+ Tuple of block output channels.
1086
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
1087
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
1088
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
1089
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
1090
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
1091
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
1092
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
1093
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
1094
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
1095
+ force_upcast (`bool`, *optional*, default to `True`):
1096
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
1097
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
1098
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
1099
+ """
1100
+
1101
+ _supports_gradient_checkpointing = True
1102
+ _no_split_modules = ["CogVideoXResnetBlock3D"]
1103
+
1104
    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        up_block_types: Tuple[str] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int] = (128, 256, 256, 512),
        latent_channels: int = 16,
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        temporal_compression_ratio: float = 4,
        sample_height: int = 480,
        sample_width: int = 720,
        scaling_factor: float = 1.15258426,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: float = True,
        use_quant_conv: bool = False,
        use_post_quant_conv: bool = False,
        invert_scale_latents: bool = False,
    ):
        """Build the CogVideoX VAE: a 3D causal encoder/decoder pair plus tiling/slicing state.

        All arguments are captured into `self.config` by `@register_to_config`, so the
        signature doubles as the serialized model configuration. Note that
        `scaling_factor`/`shift_factor`/`latents_mean`/`latents_std`/`force_upcast`/
        `invert_scale_latents` are config-only here: they are not read in this
        constructor (presumably consumed by pipelines via `self.config` — verify
        against callers).
        """
        super().__init__()

        # Encoder produces `2 * latent_channels` moments (mean + logvar).
        self.encoder = CogVideoXEncoder3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_eps=norm_eps,
            norm_num_groups=norm_num_groups,
            temporal_compression_ratio=temporal_compression_ratio,
        )
        self.decoder = CogVideoXDecoder3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_eps=norm_eps,
            norm_num_groups=norm_num_groups,
            temporal_compression_ratio=temporal_compression_ratio,
        )
        # Optional 1x1x1 (anti-overflow "safe") convolutions around the latent bottleneck.
        self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
        self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None

        # Runtime memory-saving modes; all off by default (see enable_* / disable_* methods).
        self.use_slicing = False
        self.use_tiling = False
        self.auto_split_process = False

        # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
        # recommended because the temporal parts of the VAE, here, are tricky to understand.
        # If you decode X latent frames together, the number of output frames is:
        # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
        #
        # Example with num_latent_frames_batch_size = 2:
        # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
        #     => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
        #     => 6 * 8 = 48 frames
        # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
        #     => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
        #     ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
        #     => 1 * 9 + 5 * 8 = 49 frames
        # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
        # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
        # number of temporal frames.
        self.num_latent_frames_batch_size = 2
        self.num_sample_frames_batch_size = 8

        # We make the minimum height and width of sample for tiling half that of the generally supported
        self.tile_sample_min_height = sample_height // 2
        self.tile_sample_min_width = sample_width // 2
        # Latent tile sizes: spatial downsampling is 2 ** (num_blocks - 1).
        self.tile_latent_min_height = int(
            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
        )
        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))

        # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
        # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
        # and so the tiling implementation has only been tested on those specific resolutions.
        self.tile_overlap_factor_height = 1 / 6
        self.tile_overlap_factor_width = 1 / 5
1202
+
1203
+ def _set_gradient_checkpointing(self, module, value=False):
1204
+ if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
1205
+ module.gradient_checkpointing = value
1206
+
1207
+ def enable_tiling(
1208
+ self,
1209
+ tile_sample_min_height: Optional[int] = None,
1210
+ tile_sample_min_width: Optional[int] = None,
1211
+ tile_overlap_factor_height: Optional[float] = None,
1212
+ tile_overlap_factor_width: Optional[float] = None,
1213
+ ) -> None:
1214
+ r"""
1215
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1216
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
1217
+ processing larger images.
1218
+
1219
+ Args:
1220
+ tile_sample_min_height (`int`, *optional*):
1221
+ The minimum height required for a sample to be separated into tiles across the height dimension.
1222
+ tile_sample_min_width (`int`, *optional*):
1223
+ The minimum width required for a sample to be separated into tiles across the width dimension.
1224
+ tile_overlap_factor_height (`int`, *optional*):
1225
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
1226
+ no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
1227
+ value might cause more tiles to be processed leading to slow down of the decoding process.
1228
+ tile_overlap_factor_width (`int`, *optional*):
1229
+ The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
1230
+ are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
1231
+ value might cause more tiles to be processed leading to slow down of the decoding process.
1232
+ """
1233
+ self.use_tiling = True
1234
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
1235
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
1236
+ self.tile_latent_min_height = int(
1237
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1238
+ )
1239
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1240
+ self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
1241
+ self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
1242
+
1243
+ def disable_tiling(self) -> None:
1244
+ r"""
1245
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1246
+ decoding in one step.
1247
+ """
1248
+ self.use_tiling = False
1249
+
1250
+ def enable_slicing(self) -> None:
1251
+ r"""
1252
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1253
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1254
+ """
1255
+ self.use_slicing = True
1256
+
1257
+ def disable_slicing(self) -> None:
1258
+ r"""
1259
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1260
+ decoding in one step.
1261
+ """
1262
+ self.use_slicing = False
1263
+
1264
+ def _set_first_frame(self):
1265
+ for name, module in self.named_modules():
1266
+ if isinstance(module, CogVideoXUpsample3D):
1267
+ module.auto_split_process = False
1268
+ module.first_frame_flag = True
1269
+
1270
+ def _set_rest_frame(self):
1271
+ for name, module in self.named_modules():
1272
+ if isinstance(module, CogVideoXUpsample3D):
1273
+ module.auto_split_process = False
1274
+ module.first_frame_flag = False
1275
+
1276
+ def enable_auto_split_process(self) -> None:
1277
+ self.auto_split_process = True
1278
+ for name, module in self.named_modules():
1279
+ if isinstance(module, CogVideoXUpsample3D):
1280
+ module.auto_split_process = True
1281
+
1282
+ def disable_auto_split_process(self) -> None:
1283
+ self.auto_split_process = False
1284
+
1285
    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a `[B, C, T, H, W]` video into latent moments, batching over frames.

        Frames are encoded `num_sample_frames_batch_size` at a time while threading the
        causal-conv cache between batches, so the result matches a single full pass.
        """
        batch_size, num_channels, num_frames, height, width = x.shape

        # Spatially large inputs are routed through the tiled encoder instead.
        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
            return self.tiled_encode(x)

        frame_batch_size = self.num_sample_frames_batch_size
        # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
        # As the extra single frame is handled inside the loop, it is not required to round up here.
        num_batches = max(num_frames // frame_batch_size, 1)
        conv_cache = None
        enc = []

        for i in range(num_batches):
            # The remainder frames (e.g. the leading "+1" frame) are absorbed into the
            # first batch; every later batch is shifted by `remaining_frames`.
            remaining_frames = num_frames % frame_batch_size
            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
            end_frame = frame_batch_size * (i + 1) + remaining_frames
            x_intermediate = x[:, :, start_frame:end_frame]
            # `conv_cache` carries the causal-conv state from the previous frame batch.
            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
            if self.quant_conv is not None:
                x_intermediate = self.quant_conv(x_intermediate)
            enc.append(x_intermediate)

        # Re-assemble along the temporal axis.
        enc = torch.cat(enc, dim=2)
        return enc
1310
+
1311
+ @apply_forward_hook
1312
+ def encode(
1313
+ self, x: torch.Tensor, return_dict: bool = True
1314
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1315
+ """
1316
+ Encode a batch of images into latents.
1317
+
1318
+ Args:
1319
+ x (`torch.Tensor`): Input batch of images.
1320
+ return_dict (`bool`, *optional*, defaults to `True`):
1321
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1322
+
1323
+ Returns:
1324
+ The latent representations of the encoded videos. If `return_dict` is True, a
1325
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
1326
+ """
1327
+ if self.use_slicing and x.shape[0] > 1:
1328
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
1329
+ h = torch.cat(encoded_slices)
1330
+ else:
1331
+ h = self._encode(x)
1332
+
1333
+ posterior = DiagonalGaussianDistribution(h)
1334
+
1335
+ if not return_dict:
1336
+ return (posterior,)
1337
+ return AutoencoderKLOutput(latent_dist=posterior)
1338
+
1339
    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        """Decode `[B, C, T, H, W]` latents, batching over latent frames.

        Two paths:
        - `auto_split_process=True`: uniform batches of `num_latent_frames_batch_size`
          latent frames, with remainder frames absorbed into the first batch.
        - otherwise: the first latent frame is decoded alone (upsamplers flagged via
          `_set_first_frame`), then the rest in fixed-size batches (`_set_rest_frame`).
        The causal-conv cache is threaded between batches in both paths.
        """
        batch_size, num_channels, num_frames, height, width = z.shape

        # Spatially large latents are routed through the tiled decoder instead.
        if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
            return self.tiled_decode(z, return_dict=return_dict)

        if self.auto_split_process:
            frame_batch_size = self.num_latent_frames_batch_size
            num_batches = max(num_frames // frame_batch_size, 1)
            conv_cache = None
            dec = []

            for i in range(num_batches):
                # Remainder frames go into the first batch; later batches are shifted.
                remaining_frames = num_frames % frame_batch_size
                start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
                end_frame = frame_batch_size * (i + 1) + remaining_frames
                z_intermediate = z[:, :, start_frame:end_frame]
                if self.post_quant_conv is not None:
                    z_intermediate = self.post_quant_conv(z_intermediate)
                z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
                dec.append(z_intermediate)
        else:
            conv_cache = None
            start_frame = 0
            end_frame = 1
            dec = []

            # Decode the first latent frame alone with the upsamplers in first-frame mode.
            self._set_first_frame()
            z_intermediate = z[:, :, start_frame:end_frame]
            if self.post_quant_conv is not None:
                z_intermediate = self.post_quant_conv(z_intermediate)
            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
            dec.append(z_intermediate)

            # All remaining frames are decoded in fixed-size batches in rest-frame mode.
            self._set_rest_frame()
            start_frame = end_frame
            end_frame += self.num_latent_frames_batch_size

            while start_frame < num_frames:
                z_intermediate = z[:, :, start_frame:end_frame]
                if self.post_quant_conv is not None:
                    z_intermediate = self.post_quant_conv(z_intermediate)
                z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
                dec.append(z_intermediate)
                start_frame = end_frame
                end_frame += self.num_latent_frames_batch_size

        # Re-assemble the decoded frame batches along the temporal axis.
        dec = torch.cat(dec, dim=2)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
1392
+
1393
+ @apply_forward_hook
1394
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1395
+ """
1396
+ Decode a batch of images.
1397
+
1398
+ Args:
1399
+ z (`torch.Tensor`): Input batch of latent vectors.
1400
+ return_dict (`bool`, *optional*, defaults to `True`):
1401
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1402
+
1403
+ Returns:
1404
+ [`~models.vae.DecoderOutput`] or `tuple`:
1405
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1406
+ returned.
1407
+ """
1408
+ if self.use_slicing and z.shape[0] > 1:
1409
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1410
+ decoded = torch.cat(decoded_slices)
1411
+ else:
1412
+ decoded = self._decode(z).sample
1413
+
1414
+ if not return_dict:
1415
+ return (decoded,)
1416
+ return DecoderOutput(sample=decoded)
1417
+
1418
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1419
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1420
+ for y in range(blend_extent):
1421
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1422
+ y / blend_extent
1423
+ )
1424
+ return b
1425
+
1426
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1427
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
1428
+ for x in range(blend_extent):
1429
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1430
+ x / blend_extent
1431
+ )
1432
+ return b
1433
+
1434
    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.Tensor`): Input batch of videos.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        """
        # For a rough memory estimate, take a look at the `tiled_decode` method.
        batch_size, num_channels, num_frames, height, width = x.shape

        # Tile step sizes in sample space; blend extents and crop limits in latent space.
        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
        row_limit_height = self.tile_latent_min_height - blend_extent_height
        row_limit_width = self.tile_latent_min_width - blend_extent_width
        frame_batch_size = self.num_sample_frames_batch_size

        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, overlap_height):
            row = []
            for j in range(0, width, overlap_width):
                # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
                # As the extra single frame is handled inside the loop, it is not required to round up here.
                num_batches = max(num_frames // frame_batch_size, 1)
                # conv_cache carries the causal-conv state across temporal batches of this tile,
                # so frame chunks are encoded as one continuous sequence.
                conv_cache = None
                time = []

                for k in range(num_batches):
                    # The first batch absorbs the remainder frame(s): batch 0 covers
                    # [0, frame_batch_size + remainder); later batches are exactly frame_batch_size wide.
                    remaining_frames = num_frames % frame_batch_size
                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                    end_frame = frame_batch_size * (k + 1) + remaining_frames
                    tile = x[
                        :,
                        :,
                        start_frame:end_frame,
                        i : i + self.tile_sample_min_height,
                        j : j + self.tile_sample_min_width,
                    ]
                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
                    if self.quant_conv is not None:
                        tile = self.quant_conv(tile)
                    time.append(tile)

                row.append(torch.cat(time, dim=2))
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
                # Crop away the overlapped region that was blended into the neighbor.
                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
            result_rows.append(torch.cat(result_row, dim=4))

        enc = torch.cat(result_rows, dim=3)
        return enc
1507
+
1508
    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        # Rough memory assessment:
        #   - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
        #   - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
        #   - Assume fp16 (2 bytes per value).
        # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
        #
        # Memory assessment when using tiling:
        #   - Assume everything as above but now HxW is 240x360 by tiling in half
        # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB

        batch_size, num_channels, num_frames, height, width = z.shape

        # Tile step sizes in latent space; blend extents and crop limits in sample space.
        overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
        overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
        blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
        blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
        row_limit_height = self.tile_sample_min_height - blend_extent_height
        row_limit_width = self.tile_sample_min_width - blend_extent_width
        frame_batch_size = self.num_latent_frames_batch_size

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, overlap_height):
            row = []
            for j in range(0, width, overlap_width):
                if self.auto_split_process:
                    # Automatic temporal batching: the first batch absorbs the remainder
                    # frame(s); conv_cache threads causal-conv state between batches.
                    num_batches = max(num_frames // frame_batch_size, 1)
                    conv_cache = None
                    time = []

                    for k in range(num_batches):
                        remaining_frames = num_frames % frame_batch_size
                        start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                        end_frame = frame_batch_size * (k + 1) + remaining_frames
                        tile = z[
                            :,
                            :,
                            start_frame:end_frame,
                            i : i + self.tile_latent_min_height,
                            j : j + self.tile_latent_min_width,
                        ]
                        if self.post_quant_conv is not None:
                            tile = self.post_quant_conv(tile)
                        tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                        time.append(tile)

                    row.append(torch.cat(time, dim=2))
                else:
                    # Manual split: decode latent frame 0 on its own, then the rest in
                    # chunks of num_latent_frames_batch_size.
                    conv_cache = None
                    start_frame = 0
                    end_frame = 1
                    dec = []

                    tile = z[
                        :,
                        :,
                        start_frame:end_frame,
                        i : i + self.tile_latent_min_height,
                        j : j + self.tile_latent_min_width,
                    ]

                    # NOTE(review): _set_first_frame()/_set_rest_frame() presumably toggle the
                    # decoder's first-frame handling for the causal convs — confirm against
                    # their definitions (not visible here).
                    self._set_first_frame()
                    if self.post_quant_conv is not None:
                        tile = self.post_quant_conv(tile)
                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                    dec.append(tile)

                    self._set_rest_frame()
                    start_frame = end_frame
                    end_frame += self.num_latent_frames_batch_size

                    while start_frame < num_frames:
                        tile = z[
                            :,
                            :,
                            start_frame:end_frame,
                            i : i + self.tile_latent_min_height,
                            j : j + self.tile_latent_min_width,
                        ]
                        if self.post_quant_conv is not None:
                            tile = self.post_quant_conv(tile)
                        tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                        dec.append(tile)
                        start_frame = end_frame
                        end_frame += self.num_latent_frames_batch_size

                    row.append(torch.cat(dec, dim=2))
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
                # Crop away the overlapped region that was blended into the neighbor.
                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
            result_rows.append(torch.cat(result_row, dim=4))

        dec = torch.cat(result_rows, dim=3)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
1631
+
1632
    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[torch.Tensor, torch.Tensor]:
        """
        Run a full encode -> (sample or mode) -> decode round trip.

        Args:
            sample (`torch.Tensor`): Input video batch to autoencode.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                If `True`, draw a random latent from the posterior; otherwise use its mode.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return the `decode()` result as-is or wrapped in a one-element tuple.
            generator (`torch.Generator`, *optional*):
                RNG used when sampling from the posterior.

        Returns:
            The output of `self.decode(z)`, or a one-element tuple containing it when
            `return_dict` is `False`.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z)
        if not return_dict:
            # NOTE(review): this wraps the full decode() output object (not the raw tensor) in
            # the tuple, and the declared `Union[torch.Tensor, torch.Tensor]` annotation does
            # not reflect that — confirm intended before tightening the annotation.
            return (dec,)
        return dec
1649
+
1650
+ @classmethod
1651
+ def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
1652
+ if subfolder is not None:
1653
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1654
+
1655
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1656
+ if not os.path.isfile(config_file):
1657
+ raise RuntimeError(f"{config_file} does not exist")
1658
+ with open(config_file, "r") as f:
1659
+ config = json.load(f)
1660
+
1661
+ model = cls.from_config(config, **vae_additional_kwargs)
1662
+ from diffusers.utils import WEIGHTS_NAME
1663
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1664
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
1665
+ if os.path.exists(model_file_safetensors):
1666
+ from safetensors.torch import load_file, safe_open
1667
+ state_dict = load_file(model_file_safetensors)
1668
+ else:
1669
+ if not os.path.isfile(model_file):
1670
+ raise RuntimeError(f"{model_file} does not exist")
1671
+ state_dict = torch.load(model_file, map_location="cpu")
1672
+ m, u = model.load_state_dict(state_dict, strict=False)
1673
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1674
+ print(m, u)
1675
+ return model
diffusers/dist_utils.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.distributed as dist
6
+ from diffusers.models.attention import Attention
7
+ from diffusers.models.embeddings import apply_rotary_emb
8
+
9
+ try:
10
+ import xfuser
11
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
12
+ get_sequence_parallel_world_size,
13
+ get_sp_group, get_world_group,
14
+ init_distributed_environment,
15
+ initialize_model_parallel)
16
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
17
+ except Exception as ex:
18
+ get_sequence_parallel_world_size = None
19
+ get_sequence_parallel_rank = None
20
+ xFuserLongContextAttention = None
21
+ get_sp_group = None
22
+ get_world_group = None
23
+ init_distributed_environment = None
24
+ initialize_model_parallel = None
25
+
26
def set_multi_gpus_devices(ulysses_degree, ring_degree):
    """
    Select the inference device, initializing sequence parallelism when requested.

    When `ulysses_degree > 1` or `ring_degree > 1`, this initializes `torch.distributed`
    (NCCL backend) and the xfuser sequence-parallel environment, then returns the CUDA
    device bound to this process's local rank. Otherwise it returns the string "cuda".

    Args:
        ulysses_degree: Degree of Ulysses-style sequence parallelism.
        ring_degree: Degree of ring-attention sequence parallelism.

    Returns:
        A `torch.device` (multi-GPU path) or the string "cuda" (single-GPU path).

    Raises:
        RuntimeError: If parallelism is requested but xfuser is not importable.
    """
    if ulysses_degree > 1 or ring_degree > 1:
        if get_sp_group is None:
            # xfuser import failed at module load time (see the try/except at module top).
            raise RuntimeError("xfuser is not installed.")
        dist.init_process_group("nccl")
        print('parallel inference enabled: ulysses_degree=%d ring_degree=%d rank=%d world_size=%d' % (
            ulysses_degree, ring_degree, dist.get_rank(),
            dist.get_world_size()))
        # World size must factor exactly into the two parallelism degrees.
        assert dist.get_world_size() == ring_degree * ulysses_degree, \
            "number of GPUs(%d) should be equal to ring_degree * ulysses_degree." % dist.get_world_size()
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(),
                                  ring_degree=ring_degree,
                                  ulysses_degree=ulysses_degree)
        # device = torch.device("cuda:%d" % dist.get_rank())
        # Bind to the *local* rank so multi-node launches map onto per-node GPU indices.
        device = torch.device(f"cuda:{get_world_group().local_rank}")
        print('rank=%d device=%s' % (get_world_group().rank, str(device)))
    else:
        device = "cuda"
    return device
46
+
47
class CogVideoXMultiGPUsAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.

    When xfuser is available and its distributed environment is initialized, attention over the
    video tokens is dispatched to `xFuserLongContextAttention` for sequence-parallel execution;
    otherwise it falls back to single-device `F.scaled_dot_product_attention`.
    """

    def __init__(self):
        # Constructing xFuserLongContextAttention raises when the distributed environment is
        # not initialized; in that case (or when xfuser is missing) use the single-GPU path.
        if xFuserLongContextAttention is not None:
            try:
                self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
            except Exception:
                self.hybrid_seq_parallel_attn = None
        else:
            self.hybrid_seq_parallel_attn = None
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run joint text+video attention.

        Args:
            attn: `Attention` module carrying the q/k/v/out projections and optional norms.
            hidden_states: Video-token features, shape (batch, video_seq, dim).
            encoder_hidden_states: Text-token features, shape (batch, text_seq, dim).
            attention_mask: Optional mask, prepared via `attn.prepare_attention_mask`.
            image_rotary_emb: Optional rotary embedding applied to the video portion of q/k.

        Returns:
            Tuple of `(hidden_states, encoder_hidden_states)` with the text/video split restored.
        """
        text_seq_length = encoder_hidden_states.size(1)

        # Text tokens are prepended so one attention call covers both modalities.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, inner) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE to the video tokens only; the text prefix is left untouched.
        if image_rotary_emb is not None:
            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        if self.hybrid_seq_parallel_attn is None:
            # Fix: removed a redundant `hidden_states = hidden_states` no-op that followed
            # this call in the original.
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        else:
            # Sequence-parallel path: xfuser expects (batch, seq, heads, head_dim) with the
            # text tokens passed separately as joint tensors attached at the front.
            img_q = query[:, :, text_seq_length:].transpose(1, 2)
            txt_q = query[:, :, :text_seq_length].transpose(1, 2)
            img_k = key[:, :, text_seq_length:].transpose(1, 2)
            txt_k = key[:, :, :text_seq_length].transpose(1, 2)
            img_v = value[:, :, text_seq_length:].transpose(1, 2)
            txt_v = value[:, :, :text_seq_length].transpose(1, 2)

            hidden_states = self.hybrid_seq_parallel_attn(
                None,
                img_q, img_k, img_v, dropout_p=0.0, causal=False,
                joint_tensor_query=txt_q,
                joint_tensor_key=txt_k,
                joint_tensor_value=txt_v,
                joint_strategy='front',
            ).transpose(1, 2)

        # (batch, heads, seq, head_dim) -> (batch, seq, inner)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states
diffusers/pipeline_cogvideox_fun_inpaint.py ADDED
@@ -0,0 +1,1244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from diffusers.image_processor import VaeImageProcessor
26
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
29
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
30
+ from diffusers.utils.torch_utils import randn_tensor
31
+ from diffusers.video_processor import VideoProcessor
32
+ from einops import rearrange
33
+
34
+ from transformers import T5EncoderModel, T5Tokenizer
35
+ from cogvideox_transformer3d import CogVideoXTransformer3DModel
36
+ from cogvideox_vae import AutoencoderKLCogVideoX
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ EXAMPLE_DOC_STRING = """
42
+ Examples:
43
+ ```python
44
+ pass
45
+ ```
46
+ """
47
+
48
# Adapted from diffusers.models.embeddings.get_3d_rotary_pos_embed
def get_3d_rotary_pos_embed(
    embed_dim,
    crops_coords,
    grid_size,
    temporal_size,
    theta: int = 10000,
    use_real: bool = True,
    grid_type: str = "linspace",
    max_size: Optional[Tuple[int, int]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

    Args:
        embed_dim: (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            The grid size of the spatial positional embedding (height, width).
        temporal_size (`int`):
            The size of the temporal dimension.
        theta (`float`):
            Scaling factor for frequency computation.
        use_real (`bool`):
            Must be `True`; only the real-valued (cos, sin) form is supported.
        grid_type (`str`):
            Whether to use "linspace" or "slice" to compute grids.
        max_size (`Tuple[int, int]`, *optional*):
            Maximum (height, width) of the spatial grid; required when `grid_type == "slice"`.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: (cos, sin) positional embeddings, each with shape
        `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
    """
    if use_real is not True:
        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")

    if grid_type == "linspace":
        start, stop = crops_coords
        grid_size_h, grid_size_w = grid_size
        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
        # Fix: the original assigned `np.arange(temporal_size)` and immediately overwrote it
        # with this linspace (which yields the same values); the dead assignment is removed.
        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
    elif grid_type == "slice":
        max_h, max_w = max_size
        grid_size_h, grid_size_w = grid_size
        grid_h = np.arange(max_h, dtype=np.float32)
        grid_w = np.arange(max_w, dtype=np.float32)
        grid_t = np.arange(temporal_size, dtype=np.float32)
    else:
        raise ValueError("Invalid value passed for `grid_type`.")

    # Split the head dimension across axes: 1/4 temporal, 3/8 height, 3/8 width.
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Temporal frequencies
    # Fix: `theta` was previously accepted but silently ignored; forward it so non-default
    # values take effect (default 10000 matches get_1d_rotary_pos_embed's default).
    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
    # Spatial frequencies for height and width
    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)

    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
        freqs_t = freqs_t[:, None, None, :].expand(
            -1, grid_size_h, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_t
        freqs_h = freqs_h[None, :, None, :].expand(
            temporal_size, -1, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_h
        freqs_w = freqs_w[None, None, :, :].expand(
            temporal_size, grid_size_h, -1, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_w

        freqs = torch.cat(
            [freqs_t, freqs_h, freqs_w], dim=-1
        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
        freqs = freqs.view(
            temporal_size * grid_size_h * grid_size_w, -1
        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
        return freqs

    t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w

    if grid_type == "slice":
        # Full-size grids were built above; slice down to the requested extent.
        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]

    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin
141
+
142
+
143
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
144
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """
    Compute where an aspect-ratio-preserving resize of `src` lands inside a target grid.

    Args:
        src: Source size as (height, width).
        tgt_width: Target grid width.
        tgt_height: Target grid height.

    Returns:
        `((crop_top, crop_left), (crop_bottom, crop_right))` — the centered region the
        resized image occupies within the target.
    """
    src_height, src_width = src

    # Match the more constrained dimension; scale the other proportionally.
    if src_height / src_width > tgt_height / tgt_width:
        resize_height = tgt_height
        resize_width = int(round(tgt_height / src_height * src_width))
    else:
        resize_width = tgt_width
        resize_height = int(round(tgt_width / src_width * src_height))

    crop_top = int(round((tgt_height - resize_height) / 2.0))
    crop_left = int(round((tgt_width - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
160
+
161
+
162
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Call the scheduler's `set_timesteps` method and retrieve the resulting schedule.

    Handles custom timestep or sigma schedules: at most one of `timesteps` / `sigmas` may be
    given, and the scheduler must accept the corresponding keyword in its `set_timesteps`
    signature. Any extra kwargs are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _accepts(keyword: str) -> bool:
        # True when `scheduler.set_timesteps` declares the given keyword in its signature.
        return keyword in set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    uses_custom_schedule = True
    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        uses_custom_schedule = False
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    timesteps = scheduler.timesteps
    if uses_custom_schedule:
        # The effective step count is derived from the schedule the scheduler produced.
        num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps
220
+
221
+
222
def resize_mask(mask, latent, process_first_frame_only=True):
    """
    Resize a video mask to the spatial/temporal size of `latent` via trilinear interpolation.

    Args:
        mask: Tensor of shape (batch, channels, num_frames, height, width).
        latent: Tensor whose trailing (frames, height, width) sizes define the target.
        process_first_frame_only: When True, the first mask frame is resized onto exactly one
            target frame and the remaining frames fill the rest of the target length.

    Returns:
        The resized mask tensor.
    """
    target_frames, target_height, target_width = latent.size()[2:]

    if not process_first_frame_only:
        return F.interpolate(
            mask,
            size=[target_frames, target_height, target_width],
            mode='trilinear',
            align_corners=False,
        )

    # Map the first frame onto a single latent frame so it stays temporally aligned.
    first_frame = F.interpolate(
        mask[:, :, 0:1, :, :],
        size=[1, target_height, target_width],
        mode='trilinear',
        align_corners=False,
    )

    remaining_target_frames = target_frames - 1
    if remaining_target_frames == 0:
        return first_frame

    rest = F.interpolate(
        mask[:, :, 1:, :, :],
        size=[remaining_target_frames, target_height, target_width],
        mode='trilinear',
        align_corners=False,
    )
    return torch.cat([first_frame, rest], dim=2)
257
+
258
+
259
def add_noise_to_reference_video(image, ratio=None):
    """
    Add Gaussian noise to a reference video, leaving pixels equal to -1 untouched.

    Args:
        image: Video tensor; the leading dimension is the batch and each sample gets its
            own noise scale.
        ratio: Optional fixed noise scale. When `None`, a per-sample scale is drawn from a
            log-normal distribution (log-sigma ~ N(-3.0, 0.5)).

    Returns:
        The noised video tensor.
    """
    batch = image.shape[0]
    if ratio is None:
        # Sample a per-video noise level in log space.
        log_sigma = torch.normal(mean=-3.0, std=0.5, size=(batch,)).to(image.device)
        sigma = torch.exp(log_sigma).to(image.dtype)
    else:
        sigma = ratio * torch.ones((batch,)).to(image.device, image.dtype)

    noise = torch.randn_like(image) * sigma[:, None, None, None, None]
    # Pixels with value -1 are treated as masked/padding and receive no noise.
    noise = torch.where(image == -1, torch.zeros_like(image), noise)
    return image + noise
270
+
271
+
272
@dataclass
class CogVideoXFunPipelineOutput(BaseOutput):
    r"""
    Output class for CogVideo pipelines.

    Args:
        videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Generated videos; see the docstring above for the accepted shapes/types.
    videos: torch.Tensor
285
+
286
+
287
class CogVideoXFunInpaintPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using CogVideoX.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. CogVideoX_Fun uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CogVideoXTransformer3DModel`]):
            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
    """

    # All components are required; nothing may be omitted at construction time.
    _optional_components = []
    # Component order used by diffusers' sequential CPU offloading.
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    # Tensor names that `callback_on_step_end` callbacks may read or replace
    # between denoising steps.
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]
318
+
319
    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        """Register the pipeline components and derive VAE compression factors.

        Args:
            tokenizer: T5 tokenizer matching `text_encoder`.
            text_encoder: Frozen T5 encoder producing prompt embeddings.
            vae: Causal video VAE used for encoding/decoding latents.
            transformer: Text-conditioned 3D transformer denoiser.
            scheduler: Noise scheduler driving the denoising loop.
        """
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        # Spatial downsampling factor: a factor of 2 per VAE down block;
        # falls back to 8 when no VAE is registered.
        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        # Temporal compression ratio of the video VAE; falls back to 4.
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
        )

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

        # Processors for resizing/normalizing the inpainting inputs. The mask
        # processor keeps masks as single-channel values (no normalization or
        # binarization here — thresholding happens later in `__call__`).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=False, do_convert_grayscale=True
        )
346
+
347
+ def _get_t5_prompt_embeds(
348
+ self,
349
+ prompt: Union[str, List[str]] = None,
350
+ num_videos_per_prompt: int = 1,
351
+ max_sequence_length: int = 226,
352
+ device: Optional[torch.device] = None,
353
+ dtype: Optional[torch.dtype] = None,
354
+ ):
355
+ device = device or self._execution_device
356
+ dtype = dtype or self.text_encoder.dtype
357
+
358
+ prompt = [prompt] if isinstance(prompt, str) else prompt
359
+ batch_size = len(prompt)
360
+
361
+ text_inputs = self.tokenizer(
362
+ prompt,
363
+ padding="max_length",
364
+ max_length=max_sequence_length,
365
+ truncation=True,
366
+ add_special_tokens=True,
367
+ return_tensors="pt",
368
+ )
369
+ text_input_ids = text_inputs.input_ids
370
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
371
+
372
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
373
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
374
+ logger.warning(
375
+ "The following part of your input was truncated because `max_sequence_length` is set to "
376
+ f" {max_sequence_length} tokens: {removed_text}"
377
+ )
378
+
379
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
380
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
381
+
382
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
383
+ _, seq_len, _ = prompt_embeds.shape
384
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
385
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
386
+
387
+ return prompt_embeds
388
+
389
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        # Normalize `prompt` to a list; the batch size comes from the prompt
        # list when given, otherwise from the precomputed embeddings.
        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        # Negative embeddings are only computed when CFG is active and none
        # were supplied by the caller; an absent negative prompt defaults to "".
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            # `prompt` and a list-typed `negative_prompt` must agree in type
            # and batch size before encoding.
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        # NOTE: `negative_prompt_embeds` may be None when CFG is disabled.
        return prompt_embeds, negative_prompt_embeds
469
+
470
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        video_length,
        dtype,
        device,
        generator,
        latents=None,
        video=None,
        timestep=None,
        is_strength_max=True,
        return_noise=False,
        return_video_latents=False,
    ):
        """Build the initial latent tensor for the denoising loop.

        If `latents` is None, the initial latents are pure noise when
        `is_strength_max` is True, otherwise the VAE-encoded `video` noised to
        `timestep` (img2img-style). When `latents` is given, it is treated as
        the noise and scaled by the scheduler's init sigma.

        Returns:
            Tuple of `(latents,)`, optionally extended with `noise` when
            `return_noise` is True and with the clean `video_latents` when
            `return_video_latents` is True.
        """
        # Latent layout is [B, F, C, H, W]; frames are compressed by the VAE's
        # temporal ratio and spatial dims by the spatial scale factor.
        shape = (
            batch_size,
            (video_length - 1) // self.vae_scale_factor_temporal + 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # Video latents are needed either for the caller (`return_video_latents`)
        # or to noise them to `timestep` when strength < 1.
        if return_video_latents or (latents is None and not is_strength_max):
            video = video.to(device=device, dtype=self.vae.dtype)

            # Encode one sample at a time to bound peak VAE memory.
            bs = 1
            new_video = []
            for i in range(0, video.shape[0], bs):
                video_bs = video[i : i + bs]
                video_bs = self.vae.encode(video_bs)[0]
                video_bs = video_bs.sample()
                new_video.append(video_bs)
            video = torch.cat(new_video, dim = 0)
            video = video * self.vae.config.scaling_factor

            # Broadcast a single encoded video across the requested batch, then
            # switch from [B, C, F, H, W] to the latent layout [B, F, C, H, W].
            video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
            video_latents = video_latents.to(device=device, dtype=dtype)
            video_latents = rearrange(video_latents, "b c f h w -> b f c h w")

        if latents is None:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            # if strength is 1. then initialise the latents to noise, else initial to image + noise
            latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
            # if pure noise then scale the initial latents by the Scheduler's init sigma
            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
        else:
            noise = latents.to(device)
            latents = noise * self.scheduler.init_noise_sigma

        # scale the initial noise by the standard deviation required by the scheduler
        outputs = (latents,)

        if return_noise:
            outputs += (noise,)

        if return_video_latents:
            outputs += (video_latents,)

        return outputs
537
+
538
    def prepare_mask_latents(
        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
    ):
        """VAE-encode the inpainting mask and the masked video.

        Either input may be None, in which case the corresponding output is the
        input unchanged (mask) or None (masked image). Both encoded outputs are
        scaled by the VAE scaling factor.

        Returns:
            Tuple `(mask, masked_image_latents)`.
        """
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision

        if mask is not None:
            mask = mask.to(device=device, dtype=self.vae.dtype)
            # Encode one sample at a time to bound peak VAE memory; `.mode()`
            # takes the distribution mode (deterministic) rather than sampling.
            bs = 1
            new_mask = []
            for i in range(0, mask.shape[0], bs):
                mask_bs = mask[i : i + bs]
                mask_bs = self.vae.encode(mask_bs)[0]
                mask_bs = mask_bs.mode()
                new_mask.append(mask_bs)
            mask = torch.cat(new_mask, dim = 0)
            mask = mask * self.vae.config.scaling_factor

        if masked_image is not None:
            # Optionally perturb the reference frames before encoding; only
            # done for transformers trained with noised inpaint conditioning.
            if self.transformer.config.add_noise_in_inpaint_model:
                masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
            masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
            bs = 1
            new_mask_pixel_values = []
            for i in range(0, masked_image.shape[0], bs):
                mask_pixel_values_bs = masked_image[i : i + bs]
                mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
                mask_pixel_values_bs = mask_pixel_values_bs.mode()
                new_mask_pixel_values.append(mask_pixel_values_bs)
            masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
            masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
        else:
            masked_image_latents = None

        return mask, masked_image_latents
574
+
575
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
576
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
577
+ latents = 1 / self.vae.config.scaling_factor * latents
578
+
579
+ frames = self.vae.decode(latents).sample
580
+ frames = (frames / 2 + 0.5).clamp(0, 1)
581
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
582
+ frames = frames.cpu().float().numpy()
583
+ return frames
584
+
585
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
586
+ def prepare_extra_step_kwargs(self, generator, eta):
587
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
588
+ # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers.
589
+ # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
590
+ # and should be between [0, 1]
591
+
592
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
593
+ extra_step_kwargs = {}
594
+ if accepts_eta:
595
+ extra_step_kwargs["eta"] = eta
596
+
597
+ # check if the scheduler accepts generator
598
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
599
+ if accepts_generator:
600
+ extra_step_kwargs["generator"] = generator
601
+ return extra_step_kwargs
602
+
603
    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate `__call__` arguments; raise on inconsistent combinations.

        Checks resolution divisibility, callback tensor names, and that prompts
        and precomputed embeddings are mutually exclusive and shape-consistent.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # Callbacks may only request tensors this pipeline exposes.
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        # Exactly one of `prompt` / `prompt_embeds` must be provided.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # A raw prompt cannot be mixed with precomputed negative embeddings.
        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # CFG concatenates the two embedding tensors, so their shapes must match.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
654
+
655
    def fuse_qkv_projections(self) -> None:
        r"""Enables fused QKV projections."""
        # Record the fused state so `unfuse_qkv_projections` can validate it.
        self.fusing_transformer = True
        self.transformer.fuse_qkv_projections()
659
+
660
+ def unfuse_qkv_projections(self) -> None:
661
+ r"""Disable QKV projection fusion if enabled."""
662
+ if not self.fusing_transformer:
663
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
664
+ else:
665
+ self.transformer.unfuse_qkv_projections()
666
+ self.fusing_transformer = False
667
+
668
    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 3D RoPE cos/sin tables for the transformer.

        `height`/`width` are pixel dimensions; the grid is measured in latent
        patches. `num_frames` is the number of latent frames.

        Returns:
            Tuple `(freqs_cos, freqs_sin)` on `device`.
        """
        # Patch-grid size: pixels -> latents (spatial scale factor) -> patches.
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        p_t = self.transformer.config.patch_size_t

        # Grid size the transformer was trained at, used as the RoPE base.
        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p

        if p_t is None:
            # CogVideoX 1.0: no temporal patching; crop-region style RoPE.
            grid_crops_coords = get_resize_crop_region_for_grid(
                (grid_height, grid_width), base_size_width, base_size_height
            )
            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=grid_crops_coords,
                grid_size=(grid_height, grid_width),
                temporal_size=num_frames,
            )
        else:
            # CogVideoX 1.5: frames are patched temporally; round up so a
            # partial last patch still gets an embedding.
            base_num_frames = (num_frames + p_t - 1) // p_t

            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=None,
                grid_size=(grid_height, grid_width),
                temporal_size=base_num_frames,
                grid_type="slice",
                max_size=(base_size_height, base_size_width),
            )

        freqs_cos = freqs_cos.to(device=device)
        freqs_sin = freqs_sin.to(device=device)
        return freqs_cos, freqs_sin
711
+
712
    @property
    def guidance_scale(self):
        # CFG scale recorded at the start of `__call__`.
        return self._guidance_scale

    @property
    def num_timesteps(self):
        # Number of denoising timesteps scheduled for the current call.
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        # `attention_kwargs` passed to `__call__`, stored for use during denoising.
        return self._attention_kwargs

    @property
    def interrupt(self):
        # Interrupt flag; initialized False at the start of `__call__`.
        return self._interrupt
727
+
728
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
729
+ def get_timesteps(self, num_inference_steps, strength, device):
730
+ # get the original timestep using init_timestep
731
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
732
+
733
+ t_start = max(num_inference_steps - init_timestep, 0)
734
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
735
+
736
+ return timesteps, num_inference_steps - t_start
737
+
738
+ @torch.no_grad()
739
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
740
+ def __call__(
741
+ self,
742
+ prompt: Optional[Union[str, List[str]]] = None,
743
+ negative_prompt: Optional[Union[str, List[str]]] = None,
744
+ height: int = 480,
745
+ width: int = 720,
746
+ video: Union[torch.FloatTensor] = None,
747
+ mask_video: Union[torch.FloatTensor] = None,
748
+ masked_video_latents: Union[torch.FloatTensor] = None,
749
+ num_frames: int = 49,
750
+ num_inference_steps: int = 50,
751
+ timesteps: Optional[List[int]] = None,
752
+ guidance_scale: float = 6,
753
+ use_dynamic_cfg: bool = False,
754
+ num_videos_per_prompt: int = 1,
755
+ eta: float = 0.0,
756
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
757
+ latents: Optional[torch.FloatTensor] = None,
758
+ prompt_embeds: Optional[torch.FloatTensor] = None,
759
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
760
+ output_type: str = "numpy",
761
+ return_dict: bool = False,
762
+ callback_on_step_end: Optional[
763
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
764
+ ] = None,
765
+ attention_kwargs: Optional[Dict[str, Any]] = None,
766
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
767
+ max_sequence_length: int = 226,
768
+ strength: float = 1,
769
+ noise_aug_strength: float = 0.0563,
770
+ comfyui_progressbar: bool = False,
771
+ temporal_multidiffusion_stride: int = 16,
772
+ use_trimask: bool = False,
773
+ zero_out_mask_region: bool = False,
774
+ binarize_mask: bool = False,
775
+ skip_unet: bool = False,
776
+ use_vae_mask: bool = False,
777
+ stack_mask: bool = False,
778
+ ) -> Union[CogVideoXFunPipelineOutput, Tuple]:
779
+ """
780
+ Function invoked when calling the pipeline for generation.
781
+
782
+ Args:
783
+ prompt (`str` or `List[str]`, *optional*):
784
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
785
+ instead.
786
+ negative_prompt (`str` or `List[str]`, *optional*):
787
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
788
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
789
+ less than `1`).
790
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
791
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
792
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
793
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
794
+ num_frames (`int`, defaults to `48`):
795
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
796
+ contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
797
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
798
+ needs to be satisfied is that of divisibility mentioned above.
799
+ num_inference_steps (`int`, *optional*, defaults to 50):
800
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
801
+ expense of slower inference.
802
+ timesteps (`List[int]`, *optional*):
803
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
804
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
805
+ passed will be used. Must be in descending order.
806
+ guidance_scale (`float`, *optional*, defaults to 7.0):
807
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
808
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
809
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
810
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
811
+ usually at the expense of lower image quality.
812
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
813
+ The number of videos to generate per prompt.
814
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
815
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
816
+ to make generation deterministic.
817
+ latents (`torch.FloatTensor`, *optional*):
818
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
819
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
820
+ tensor will ge generated by sampling using the supplied random `generator`.
821
+ prompt_embeds (`torch.FloatTensor`, *optional*):
822
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
823
+ provided, text embeddings will be generated from `prompt` input argument.
824
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
825
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
826
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
827
+ argument.
828
+ output_type (`str`, *optional*, defaults to `"pil"`):
829
+ The output format of the generate image. Choose between
830
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
831
+ return_dict (`bool`, *optional*, defaults to `True`):
832
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
833
+ of a plain tuple.
834
+ callback_on_step_end (`Callable`, *optional*):
835
+ A function that calls at the end of each denoising steps during the inference. The function is called
836
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
837
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
838
+ `callback_on_step_end_tensor_inputs`.
839
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
840
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
841
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
842
+ `._callback_tensor_inputs` attribute of your pipeline class.
843
+ max_sequence_length (`int`, defaults to `226`):
844
+ Maximum sequence length in encoded prompt. Must be consistent with
845
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
846
+
847
+ Examples:
848
+
849
+ Returns:
850
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
851
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
852
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
853
+ """
854
+
855
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
856
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
857
+
858
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
859
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
860
+ num_frames = num_frames or self.transformer.config.sample_frames
861
+
862
+ num_videos_per_prompt = 1
863
+
864
+ # 1. Check inputs. Raise error if not correct
865
+ self.check_inputs(
866
+ prompt,
867
+ height,
868
+ width,
869
+ negative_prompt,
870
+ callback_on_step_end_tensor_inputs,
871
+ prompt_embeds,
872
+ negative_prompt_embeds,
873
+ )
874
+ self._guidance_scale = guidance_scale
875
+ self._attention_kwargs = attention_kwargs
876
+ self._interrupt = False
877
+
878
+ # 2. Default call parameters
879
+ if prompt is not None and isinstance(prompt, str):
880
+ batch_size = 1
881
+ elif prompt is not None and isinstance(prompt, list):
882
+ batch_size = len(prompt)
883
+ else:
884
+ batch_size = prompt_embeds.shape[0]
885
+
886
+ device = self._execution_device
887
+
888
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
889
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
890
+ # corresponds to doing no classifier free guidance.
891
+ do_classifier_free_guidance = guidance_scale > 1.0
892
+ logger.info(f'Use cfg: {do_classifier_free_guidance}, guidance_scale={guidance_scale}')
893
+
894
+ # 3. Encode input prompt
895
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
896
+ prompt,
897
+ negative_prompt,
898
+ do_classifier_free_guidance,
899
+ num_videos_per_prompt=num_videos_per_prompt,
900
+ prompt_embeds=prompt_embeds,
901
+ negative_prompt_embeds=negative_prompt_embeds,
902
+ max_sequence_length=max_sequence_length,
903
+ device=device,
904
+ )
905
+ if do_classifier_free_guidance:
906
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
907
+
908
+ # 4. set timesteps
909
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
910
+ timesteps, num_inference_steps = self.get_timesteps(
911
+ num_inference_steps=num_inference_steps, strength=strength, device=device
912
+ )
913
+ self._num_timesteps = len(timesteps)
914
+ if comfyui_progressbar:
915
+ from comfy.utils import ProgressBar
916
+ pbar = ProgressBar(num_inference_steps + 2)
917
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
918
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
919
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
920
+ is_strength_max = strength == 1.0
921
+
922
+ # 5. Prepare latents.
923
+ if video is not None:
924
+ video_length = video.shape[2]
925
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
926
+ init_video = init_video.to(dtype=torch.float32)
927
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
928
+ else:
929
+ video_length = num_frames
930
+ init_video = None
931
+
932
+ # Magvae needs the number of frames to be 4n + 1.
933
+ local_latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1
934
+ # For CogVideoX 1.5, the latent frames should be clipped to make it divisible by patch_size_t
935
+ patch_size_t = self.transformer.config.patch_size_t
936
+ additional_frames = 0
937
+ if patch_size_t is not None and local_latent_length % patch_size_t != 0:
938
+ additional_frames = local_latent_length % patch_size_t
939
+ num_frames -= additional_frames * self.vae_scale_factor_temporal
940
+ if num_frames <= 0:
941
+ num_frames = 1
942
+
943
+ num_channels_latents = self.vae.config.latent_channels
944
+ num_channels_transformer = self.transformer.config.in_channels
945
+ return_image_latents = num_channels_transformer == num_channels_latents
946
+
947
+ latents_outputs = self.prepare_latents(
948
+ batch_size * num_videos_per_prompt,
949
+ num_channels_latents,
950
+ height,
951
+ width,
952
+ video_length,
953
+ prompt_embeds.dtype,
954
+ device,
955
+ generator,
956
+ latents,
957
+ video=init_video,
958
+ timestep=latent_timestep,
959
+ is_strength_max=is_strength_max,
960
+ return_noise=True,
961
+ return_video_latents=return_image_latents,
962
+ )
963
+ if return_image_latents:
964
+ latents, noise, image_latents = latents_outputs
965
+ else:
966
+ latents, noise = latents_outputs
967
+ if comfyui_progressbar:
968
+ pbar.update(1)
969
+
970
+ if mask_video is not None:
971
+ if (mask_video == 255).all():
972
+ mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype)
973
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
974
+
975
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
976
+ masked_video_latents_input = (
977
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
978
+ )
979
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
980
+ else:
981
+ # Prepare mask latent variables
982
+ video_length = video.shape[2]
983
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
984
+ if use_trimask:
985
+ mask_condition = torch.where(mask_condition > 0.75, 1., mask_condition)
986
+ mask_condition = torch.where((mask_condition <= 0.75) * (mask_condition >= 0.25), 127. / 255., mask_condition)
987
+ mask_condition = torch.where(mask_condition < 0.25, 0., mask_condition)
988
+ else:
989
+ mask_condition = torch.where(mask_condition > 0.5, 1., 0.)
990
+
991
+ mask_condition = mask_condition.to(dtype=torch.float32)
992
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
993
+
994
+ if num_channels_transformer != num_channels_latents:
995
+ mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
996
+ if masked_video_latents is None:
997
+ if zero_out_mask_region:
998
+ masked_video = init_video * (mask_condition_tile < 0.75) + torch.ones_like(init_video) * (mask_condition_tile > 0.75) * -1
999
+ else:
1000
+ masked_video = init_video
1001
+ else:
1002
+ masked_video = masked_video_latents
1003
+
1004
+ mask_encoded, masked_video_latents = self.prepare_mask_latents(
1005
+ 1 - mask_condition_tile if use_vae_mask else None,
1006
+ masked_video,
1007
+ batch_size,
1008
+ height,
1009
+ width,
1010
+ prompt_embeds.dtype,
1011
+ device,
1012
+ generator,
1013
+ do_classifier_free_guidance,
1014
+ noise_aug_strength=noise_aug_strength,
1015
+ )
1016
+ if not use_vae_mask and not stack_mask:
1017
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
1018
+ if binarize_mask:
1019
+ if use_trimask:
1020
+ mask_latents = torch.where(mask_latents > 0.75, 1., mask_latents)
1021
+ mask_latents = torch.where((mask_latents <= 0.75) * (mask_latents >= 0.25), 0.5, mask_latents)
1022
+ mask_latents = torch.where(mask_latents < 0.25, 0., mask_latents)
1023
+ else:
1024
+ mask_latents = torch.where(mask_latents < 0.9, 0., 1.).to(mask_latents.dtype)
1025
+
1026
+ mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
1027
+
1028
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
1029
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
1030
+
1031
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
1032
+ mask = rearrange(mask, "b c f h w -> b f c h w")
1033
+ elif stack_mask:
1034
+ mask_latents = torch.cat([
1035
+ torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
1036
+ mask_condition[:, :, 1:],
1037
+ ], dim=2)
1038
+ mask_latents = mask_latents.view(
1039
+ mask_latents.shape[0],
1040
+ mask_latents.shape[2] // 4,
1041
+ 4,
1042
+ mask_latents.shape[3],
1043
+ mask_latents.shape[4],
1044
+ )
1045
+ mask_latents = mask_latents.transpose(1, 2)
1046
+ mask_latents = resize_mask(1 - mask_latents, masked_video_latents).to(latents.device, latents.dtype)
1047
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
1048
+ else:
1049
+ mask_input = (
1050
+ torch.cat([mask_encoded] * 2) if do_classifier_free_guidance else mask_encoded
1051
+ )
1052
+
1053
+ masked_video_latents_input = (
1054
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
1055
+ )
1056
+
1057
+ mask_input = rearrange(mask_input, "b c f h w -> b f c h w")
1058
+ masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w")
1059
+
1060
+ # concat(binary mask, encode(mask * video))
1061
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
1062
+ else:
1063
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
1064
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
1065
+ mask = rearrange(mask, "b c f h w -> b f c h w")
1066
+
1067
+ inpaint_latents = None
1068
+ else:
1069
+ if num_channels_transformer != num_channels_latents:
1070
+ mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
1071
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
1072
+
1073
+ mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
1074
+ masked_video_latents_input = (
1075
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
1076
+ )
1077
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
1078
+ else:
1079
+ mask = torch.zeros_like(init_video[:, :1])
1080
+ mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
1081
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
1082
+ mask = rearrange(mask, "b c f h w -> b f c h w")
1083
+
1084
+ inpaint_latents = None
1085
+ if comfyui_progressbar:
1086
+ pbar.update(1)
1087
+
1088
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1089
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1090
+ logger.debug(f'Pipeline mask {mask_condition.shape} {mask_condition.dtype} {mask_condition.min()} {mask_condition.max()}')
1091
+ # 8. Denoising loop
1092
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1093
+ latent_temporal_window_size = (num_frames - 1) // 4 + 1
1094
+ if latents.size(1) > latent_temporal_window_size:
1095
+ logger.info(f'Adopt temporal multidiffusion for the latents {latents.shape} {latents.dtype}')
1096
+
1097
+ # VAE experiment
1098
+ if skip_unet:
1099
+ masked_video_latents = rearrange(masked_video_latents, "b c f h w -> b f c h w")
1100
+ if output_type == "numpy":
1101
+ video = self.decode_latents(masked_video_latents)
1102
+ elif not output_type == "latent":
1103
+ video = self.decode_latents(masked_video_latents)
1104
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
1105
+ else:
1106
+ video = masked_video_latents
1107
+
1108
+ # Offload all models
1109
+ self.maybe_free_model_hooks()
1110
+
1111
+ if not return_dict:
1112
+ video = torch.from_numpy(video)
1113
+
1114
+ return CogVideoXFunPipelineOutput(videos=video)
1115
+
1116
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1117
+ # for DPM-solver++
1118
+ old_pred_original_sample = None
1119
+ for i, t in enumerate(timesteps):
1120
+ if self.interrupt:
1121
+ continue
1122
+
1123
+ def _sample(_latents, _inpaint_latents):
1124
+ # 7. Create rotary embeds if required
1125
+ image_rotary_emb = (
1126
+ self._prepare_rotary_positional_embeddings(height, width, _latents.size(1), device)
1127
+ if self.transformer.config.use_rotary_positional_embeddings
1128
+ else None
1129
+ )
1130
+
1131
+ latent_model_input = torch.cat([_latents] * 2) if do_classifier_free_guidance else _latents
1132
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1133
+
1134
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1135
+ timestep = t.expand(latent_model_input.shape[0])
1136
+
1137
+ # predict noise model_output
1138
+ noise_pred = self.transformer(
1139
+ hidden_states=latent_model_input,
1140
+ encoder_hidden_states=prompt_embeds,
1141
+ timestep=timestep,
1142
+ image_rotary_emb=image_rotary_emb,
1143
+ return_dict=False,
1144
+ inpaint_latents=_inpaint_latents,
1145
+ )[0]
1146
+ noise_pred = noise_pred.float()
1147
+
1148
+ # perform guidance
1149
+ if use_dynamic_cfg:
1150
+ self._guidance_scale = 1 + guidance_scale * (
1151
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
1152
+ )
1153
+ if do_classifier_free_guidance:
1154
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1155
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1156
+
1157
+ # compute the previous noisy sample x_t -> x_t-1
1158
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
1159
+ _latents = self.scheduler.step(noise_pred, t, _latents, **extra_step_kwargs, return_dict=False)[0]
1160
+ else:
1161
+ _latents, old_pred_original_sample = self.scheduler.step(
1162
+ noise_pred,
1163
+ old_pred_original_sample,
1164
+ t,
1165
+ timesteps[i - 1] if i > 0 else None,
1166
+ _latents,
1167
+ **extra_step_kwargs,
1168
+ return_dict=False,
1169
+ )
1170
+ _latents = _latents.to(prompt_embeds.dtype)
1171
+ return _latents
1172
+
1173
+ if latents.size(1) <= latent_temporal_window_size:
1174
+ latents = _sample(latents, inpaint_latents)
1175
+ else:
1176
+ # adopt temporal multidiffusion
1177
+ latents_canvas = torch.zeros_like(latents).float()
1178
+ weights_canvas = torch.zeros(1, latents.size(1), 1, 1, 1).to(latents.device).float()
1179
+ temporal_stride = temporal_multidiffusion_stride // 4
1180
+ assert latent_temporal_window_size > temporal_stride
1181
+
1182
+ time_beg = 0
1183
+ while time_beg < latents.size(1):
1184
+ time_end = min(time_beg + latent_temporal_window_size, latents.size(1))
1185
+
1186
+ latents_i = latents[:, time_beg:time_end]
1187
+ if inpaint_latents is not None:
1188
+ inpaint_latents_i = inpaint_latents[:, time_beg:time_end]
1189
+ else:
1190
+ inpaint_latents_i = None
1191
+
1192
+ latents_i = _sample(latents_i, inpaint_latents_i)
1193
+
1194
+ weights_i = torch.ones(1, time_end - time_beg, 1, 1, 1).to(latents.device).to(latents.dtype)
1195
+ if time_beg > 0 and temporal_stride > 0:
1196
+ weights_i[:, :temporal_stride] = (torch.linspace(0., 1., temporal_stride + 2)[1:-1]
1197
+ .to(latents.device)
1198
+ .to(latents.dtype)
1199
+ .reshape(1, temporal_stride, 1, 1, 1))
1200
+ if time_end < latents.size(1) and temporal_stride > 0:
1201
+ weights_i[:, -temporal_stride:] = (torch.linspace(1., 0., temporal_stride + 2)[1:-1]
1202
+ .to(latents.device)
1203
+ .to(latents.dtype)
1204
+ .reshape(1, temporal_stride, 1, 1, 1))
1205
+
1206
+ latents_canvas[:, time_beg:time_end] += latents_i * weights_i
1207
+ weights_canvas[:, time_beg:time_end] += weights_i
1208
+
1209
+ time_beg = time_end - temporal_stride
1210
+ if time_end >= latents.size(1):
1211
+ break
1212
+ latents = (latents_canvas / weights_canvas).to(latents.dtype)
1213
+
1214
+ # call the callback, if provided
1215
+ if callback_on_step_end is not None:
1216
+ callback_kwargs = {}
1217
+ for k in callback_on_step_end_tensor_inputs:
1218
+ callback_kwargs[k] = locals()[k]
1219
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1220
+
1221
+ latents = callback_outputs.pop("latents", latents)
1222
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1223
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1224
+
1225
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1226
+ progress_bar.update()
1227
+ if comfyui_progressbar:
1228
+ pbar.update(1)
1229
+
1230
+ if output_type == "numpy":
1231
+ video = self.decode_latents(latents)
1232
+ elif not output_type == "latent":
1233
+ video = self.decode_latents(latents)
1234
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
1235
+ else:
1236
+ video = latents
1237
+
1238
+ # Offload all models
1239
+ self.maybe_free_model_hooks()
1240
+
1241
+ if not return_dict:
1242
+ video = torch.from_numpy(video)
1243
+
1244
+ return CogVideoXFunPipelineOutput(videos=video)
diffusers/pipeline_void.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VOID (Video Object and Interaction Deletion) Pipeline.
3
+
4
+ Simple usage:
5
+
6
+ from pipeline_void import VOIDPipeline
7
+
8
+ pipe = VOIDPipeline.from_pretrained("netflix/void-model")
9
+ result = pipe.inpaint("input.mp4", "quadmask.mp4", "A lime falls on the table.")
10
+ result.save("output.mp4")
11
+
12
+ Pass 2 refinement:
13
+
14
+ pipe2 = VOIDPipeline.from_pretrained("netflix/void-model", void_pass=2)
15
+ result2 = pipe2.inpaint("input.mp4", "quadmask.mp4", "A lime falls on the table.",
16
+ pass1_video="output.mp4")
17
+ result2.save("output_refined.mp4")
18
+ """
19
+
20
+ import os
21
+ import json
22
+ import subprocess
23
+ import sys
24
+ import tempfile
25
+ from dataclasses import dataclass
26
+ from typing import List, Optional, Tuple, Union
27
+
28
+ import cv2
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn.functional as F
32
+ from huggingface_hub import hf_hub_download, snapshot_download
33
+ from safetensors.torch import load_file
34
+ from diffusers import CogVideoXDDIMScheduler
35
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
36
+
37
+ from cogvideox_transformer3d import CogVideoXTransformer3DModel
38
+ from cogvideox_vae import AutoencoderKLCogVideoX
39
+ from pipeline_cogvideox_fun_inpaint import CogVideoXFunInpaintPipeline
40
+
41
# The base model that VOID is fine-tuned from; the VOID checkpoints below are
# merged on top of this CogVideoX-Fun inpainting model at load time.
BASE_MODEL_REPO = "alibaba-pai/CogVideoX-Fun-V1.5-5b-InP"

# Checkpoint filenames in the VOID repo, keyed by refinement pass number
# (pass 1 = initial removal, pass 2 = warped-noise refinement).
PASS_CHECKPOINTS = {
    1: "void_pass1.safetensors",
    2: "void_pass2.safetensors",
}

# Default negative prompt (from config/quadmask_cogvideox.py)
DEFAULT_NEGATIVE_PROMPT = (
    "The video is not of a high quality, it has a low resolution. "
    "Watermark present in each frame. The background is solid. "
    "Strange body and strange trajectory. Distortion. "
)
56
+
57
+
58
@dataclass
class VOIDOutput:
    """Result container returned by the VOID pipeline."""
    # Final frames, shape (T, H, W, 3), uint8.
    video: torch.Tensor
    # Raw pipeline output, shape (1, C, T, H, W), float values in [0, 1].
    video_float: torch.Tensor

    def save(self, path: str, fps: int = 12):
        """Write the uint8 frames to *path* as a video at *fps* frames/sec."""
        import imageio
        frame_list = list(self.video.cpu().numpy())
        imageio.mimwrite(path, frame_list, fps=fps)
        print(f"Saved {len(frame_list)} frames to {path}")
70
+
71
+
72
def _merge_void_weights(transformer, checkpoint_path):
    """Load a VOID checkpoint into *transformer*.

    When the checkpoint's patch-embed projection has a different input-channel
    count than the base model, only the leading and trailing feature slices of
    that weight are taken from the checkpoint; the middle channels keep the
    base model's values.
    """
    ckpt = load_file(checkpoint_path)
    key = "patch_embed.proj.weight"

    base_sd = transformer.state_dict()
    if ckpt[key].size(1) != base_sd[key].size(1):
        # 16 latent channels scaled by 8 -> 128 feature channels per side.
        feat_dim = int(16 * 8)

        patched = base_sd[key].clone()
        patched[:, :feat_dim] = ckpt[key][:, :feat_dim]
        patched[:, -feat_dim:] = ckpt[key][:, -feat_dim:]
        ckpt[key] = patched

    missing, unexpected = transformer.load_state_dict(ckpt, strict=False)
    if missing:
        print(f"[VOID] Missing keys: {len(missing)}")
    if unexpected:
        print(f"[VOID] Unexpected keys: {len(unexpected)}")

    return transformer
94
+
95
+
96
def _load_video(path: str, max_frames: int) -> np.ndarray:
    """Read up to *max_frames* frames from *path* as a (T, H, W, 3) uint8 array."""
    import imageio
    # NOTE(review): `imiter` is the imageio v3 API — confirm the installed
    # imageio version exposes it at the top level.
    all_frames = list(imageio.imiter(path))
    return np.array(all_frames[:max_frames])
102
+
103
+
104
+ def _prep_video_tensor(
105
+ video_np: np.ndarray,
106
+ sample_size: Tuple[int, int],
107
+ ) -> torch.Tensor:
108
+ """Convert video numpy array to pipeline input tensor.
109
+
110
+ Returns: (1, C, T, H, W) float32 in [0, 1]
111
+ """
112
+ video = torch.from_numpy(video_np).float()
113
+ video = video.permute(3, 0, 1, 2) / 255.0 # (C, T, H, W)
114
+ video = F.interpolate(video, sample_size, mode="area")
115
+ return video.unsqueeze(0) # (1, C, T, H, W)
116
+
117
+
118
+ def _prep_mask_tensor(
119
+ mask_np: np.ndarray,
120
+ sample_size: Tuple[int, int],
121
+ use_quadmask: bool = True,
122
+ ) -> torch.Tensor:
123
+ """Convert mask numpy array to pipeline input tensor.
124
+
125
+ Quantizes to quadmask values [0, 63, 127, 255], inverts,
126
+ and normalizes to [0, 1].
127
+
128
+ Returns: (1, 1, T, H, W) float32 in [0, 1]
129
+ """
130
+ mask = torch.from_numpy(mask_np).float()
131
+ if mask.ndim == 4:
132
+ mask = mask[..., 0] # drop channel dim -> (T, H, W)
133
+ mask = F.interpolate(mask.unsqueeze(0), sample_size, mode="area")
134
+ mask = mask.unsqueeze(0) # (1, 1, T, H, W)
135
+
136
+ if use_quadmask:
137
+ # Quantize to 4 values
138
+ mask = torch.where(mask <= 31, 0., mask)
139
+ mask = torch.where((mask > 31) * (mask <= 95), 63., mask)
140
+ mask = torch.where((mask > 95) * (mask <= 191), 127., mask)
141
+ mask = torch.where(mask > 191, 255., mask)
142
+ else:
143
+ # Trimask: 3 values
144
+ mask = torch.where(mask > 192, 255., mask)
145
+ mask = torch.where((mask <= 192) * (mask >= 64), 128., mask)
146
+ mask = torch.where(mask < 64, 0., mask)
147
+
148
+ # Invert and normalize to [0, 1]
149
+ mask = (255. - mask) / 255.
150
+
151
+ return mask
152
+
153
+
154
def _temporal_padding(
    tensor: torch.Tensor,
    min_length: int = 85,
    max_length: int = 197,
    dim: int = 2,
) -> torch.Tensor:
    """Pad video temporally by mirroring, matching CogVideoX requirements.

    The target frame count is chosen so that length ≡ 1 (mod 4) — the shape
    CogVideoX's temporal VAE compression expects — then clamped into
    [min_length, max_length]. Shorter inputs are extended by repeatedly
    appending a time-reversed copy; longer inputs are truncated.

    Args:
        tensor: Video tensor with time on axis *dim* (only dim=2 supported).
        min_length: Lower bound on the output frame count.
        max_length: Upper bound on the output frame count.
        dim: Temporal axis; must be 2 (i.e. (B, C, T, H, W) layout).

    Returns:
        The same tensor truncated/mirror-padded to the target length on *dim*.
    """
    length = tensor.size(dim)

    # Round down to a multiple of 4, plus 1 (CogVideoX frame-count constraint).
    min_len = (length // 4) * 4 + 1
    if min_len < length:
        min_len += 4
    # NOTE(review): this uses float division, and min_len ≡ 1 (mod 4) means
    # min_len / 4 is never an integer, so this branch can never fire.
    # Possibly `(min_len // 4) % 2 == 0` was intended — confirm before changing,
    # since altering it changes output lengths downstream models depend on.
    if (min_len / 4) % 2 == 0:
        min_len += 4
    target_length = min(min_len, max_length)
    target_length = max(min_length, target_length)

    # Truncate if needed
    if dim == 2:
        tensor = tensor[:, :, :target_length]
    else:
        raise NotImplementedError(f"dim={dim} not supported")

    # Pad by mirroring: append time-reversed copies until long enough.
    while tensor.size(dim) < target_length:
        flipped = torch.flip(tensor, [dim])
        tensor = torch.cat([tensor, flipped], dim=dim)

    # Trim the overshoot from the final mirror append.
    if dim == 2:
        tensor = tensor[:, :, :target_length]

    return tensor
186
+
187
+
188
def _generate_warped_noise(
    pass1_video_path: str,
    target_shape: Tuple[int, int, int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Generate warped noise from Pass 1 output video.

    Tries the in-process `rp`/noise_warp path first; if the `rp` package is
    unavailable or fails, falls back to running make_warped_noise.py as a
    subprocess and loading the resulting noises.npy.

    Args:
        pass1_video_path: Path to Pass 1 output video.
        target_shape: (latent_T, latent_H, latent_W, latent_C)
        device: Target device.
        dtype: Target dtype.

    Returns: (1, T, C, H, W) warped noise tensor.

    Raises:
        RuntimeError: if neither `rp` nor make_warped_noise.py is available,
            or the subprocess fails / produces no noise file.
    """
    # Try to import rp and nw for direct warped noise generation
    try:
        # Fix for SLURM: rp crashes parsing GPU UUIDs like "GPU-9fca2b4f-..."
        # Set CUDA_VISIBLE_DEVICES to numeric index if it contains UUIDs.
        # NOTE(review): this hard-codes device 0 for the whole process.
        cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        if cuda_env and not cuda_env.replace(",", "").isdigit():
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"

        import rp
        rp.r._pip_import_autoyes = True
        rp.git_import('CommonSource')
        import rp.git.CommonSource.noise_warp as nw
        return _generate_warped_noise_direct(pass1_video_path, target_shape, device, dtype)
    except ImportError as e:
        # Missing dependency: fall through to the subprocess fallback below.
        print(f"[VOID] rp/noise_warp not available: {e}")
    except Exception as e:
        # Any runtime failure inside rp: log the traceback, then fall back.
        print(f"[VOID] Warped noise generation via rp failed: {e}")
        import traceback
        traceback.print_exc()

    # Fallback: try to find and run make_warped_noise.py as subprocess,
    # looking next to this file first, then in the inference tree.
    script_candidates = [
        os.path.join(os.path.dirname(__file__), "make_warped_noise.py"),
        os.path.join(os.path.dirname(__file__), "..", "inference", "cogvideox_fun", "make_warped_noise.py"),
    ]
    gwf_script = None
    for candidate in script_candidates:
        if os.path.exists(candidate):
            gwf_script = candidate
            break

    if gwf_script is None:
        raise RuntimeError(
            "Cannot generate warped noise: 'rp' package not installed and "
            "make_warped_noise.py not found. Install 'rp' package or provide "
            "pre-computed warped noise via warped_noise_path parameter."
        )

    with tempfile.TemporaryDirectory() as tmpdir:
        cmd = [sys.executable, gwf_script, os.path.abspath(pass1_video_path), tmpdir]
        print(f"[VOID] Generating warped noise (this may take a few minutes)...")
        # 10-minute cap on the subprocess; list-form argv, shell=False.
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
        if result.returncode != 0:
            raise RuntimeError(f"Warped noise generation failed:\n{result.stderr}")

        # Find the output noises.npy (script nests it under the video stem)
        video_stem = os.path.splitext(os.path.basename(pass1_video_path))[0]
        noise_path = os.path.join(tmpdir, video_stem, "noises.npy")
        if not os.path.exists(noise_path):
            # Try flat path
            noise_path = os.path.join(tmpdir, "noises.npy")
        if not os.path.exists(noise_path):
            raise RuntimeError(f"Warped noise file not found after generation")

        # Load inside the `with` block — tmpdir is deleted on exit.
        return _load_warped_noise(noise_path, target_shape, device, dtype)
259
+
260
+
261
def _generate_warped_noise_direct(
    video_path: str,
    target_shape: Tuple[int, int, int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Generate warped noise directly using the `rp` package (no subprocess).

    Args:
        video_path: Path to the video whose motion drives the noise warp.
        target_shape: (latent_T, latent_H, latent_W, latent_C)
        device: Target device.
        dtype: Target dtype.

    Returns: (1, T, C, H, W) warped noise tensor.
    """
    import rp
    import rp.git.CommonSource.noise_warp as nw

    # Normalize the clip to 72 frames at 480x720 (resize to cover, then
    # center-crop). NOTE(review): presumably chosen to match the model's
    # working resolution — confirm if resolution changes.
    video = rp.load_video(video_path)
    video = rp.resize_list(video, length=72)
    video = rp.resize_images_to_hold(video, height=480, width=720)
    video = rp.crop_images(video, height=480, width=720, origin='center')
    video = rp.as_numpy_array(video)

    # Scale factors fed to the noise warper: frames at 1/2 resolution,
    # flow at 8x, with an 8x latent downscale.
    FRAME = 2**-1
    FLOW = 2**3
    LATENT = 8

    output = nw.get_noise_from_video(
        video,
        remove_background=False,
        visualize=False,
        save_files=False,
        noise_channels=16,
        resize_frames=FRAME,
        resize_flow=FLOW,
        downscale_factor=round(FRAME * FLOW) * LATENT,
    )

    noises = output.numpy_noises  # (T, H, W, C)
    return _numpy_noise_to_tensor(noises, target_shape, device, dtype)
294
+
295
+
296
def _load_warped_noise(
    noise_path: str,
    target_shape: Tuple[int, int, int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Load pre-computed warped noise from disk and adapt it to *target_shape*.

    Returns: (1, T, C, H, W) tensor on *device* with *dtype*.
    """
    arr = np.load(noise_path)
    # float16 upcast before any interpolation arithmetic.
    if arr.dtype == np.float16:
        arr = arr.astype(np.float32)
    # Accept TCHW storage (16 noise channels) and convert to THWC.
    if arr.shape[1] == 16:
        arr = np.transpose(arr, (0, 2, 3, 1))
    return _numpy_noise_to_tensor(arr, target_shape, device, dtype)
310
+
311
+
312
+ def _numpy_noise_to_tensor(
313
+ noises: np.ndarray,
314
+ target_shape: Tuple[int, int, int, int],
315
+ device: torch.device,
316
+ dtype: torch.dtype,
317
+ ) -> torch.Tensor:
318
+ """Convert numpy noise (T, H, W, C) to pipeline tensor (1, T, C, H, W)."""
319
+ latent_T, latent_H, latent_W, latent_C = target_shape
320
+
321
+ # Temporal resize if needed
322
+ if noises.shape[0] != latent_T:
323
+ indices = np.linspace(0, noises.shape[0] - 1, latent_T)
324
+ lower = np.floor(indices).astype(int)
325
+ upper = np.ceil(indices).astype(int)
326
+ frac = indices - lower
327
+ noises = noises[lower] * (1 - frac[:, None, None, None]) + noises[upper] * frac[:, None, None, None]
328
+
329
+ # Spatial resize if needed
330
+ if noises.shape[1] != latent_H or noises.shape[2] != latent_W:
331
+ resized = np.zeros((latent_T, latent_H, latent_W, latent_C), dtype=noises.dtype)
332
+ for t in range(latent_T):
333
+ for c in range(latent_C):
334
+ resized[t, :, :, c] = cv2.resize(
335
+ noises[t, :, :, c], (latent_W, latent_H),
336
+ interpolation=cv2.INTER_LINEAR,
337
+ )
338
+ noises = resized
339
+
340
+ # Convert to tensor: (T, H, W, C) -> (1, T, C, H, W)
341
+ tensor = torch.from_numpy(noises).permute(0, 3, 1, 2).unsqueeze(0)
342
+ return tensor.to(device=device, dtype=dtype)
343
+
344
+
345
class VOIDPipeline(CogVideoXFunInpaintPipeline):
    """
    VOID: Video Object and Interaction Deletion.

    Removes objects and their physical interactions from videos using
    quadmask-conditioned video inpainting.
    """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        void_pass: int = 1,
        base_model: str = BASE_MODEL_REPO,
        torch_dtype: torch.dtype = torch.bfloat16,
        **kwargs,
    ):
        """
        Load the VOID pipeline.

        Downloads the VOID checkpoint and the base CogVideoX-Fun components
        (transformer, VAE, tokenizer, text encoder, scheduler), merges the
        VOID weights into the transformer, and assembles the pipeline.

        Args:
            pretrained_model_name_or_path: HF repo ID or local path containing
                VOID checkpoint files (void_pass1.safetensors, etc.)
            void_pass: Which pass checkpoint to load (1 or 2). Default: 1.
            base_model: HF repo ID for the base CogVideoX-Fun model.
            torch_dtype: Weight dtype. Default: torch.bfloat16.

        Raises:
            ValueError: if *void_pass* is not a known pass number.
        """
        if void_pass not in PASS_CHECKPOINTS:
            raise ValueError(f"void_pass must be 1 or 2, got {void_pass}")

        # --- Download VOID checkpoint ---
        checkpoint_name = PASS_CHECKPOINTS[void_pass]
        print(f"[VOID] Loading Pass {void_pass} checkpoint...")

        if os.path.isdir(pretrained_model_name_or_path):
            checkpoint_path = os.path.join(pretrained_model_name_or_path, checkpoint_name)
            if not os.path.exists(checkpoint_path):
                # Check parent dir (checkpoints at root, code in diffusers/)
                # NOTE(review): no existence check after this fallback — a
                # missing file surfaces later in _merge_void_weights.
                checkpoint_path = os.path.join(pretrained_model_name_or_path, "..", checkpoint_name)
        else:
            checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename=checkpoint_name,
            )

        # --- Download and load base model ---
        print(f"[VOID] Loading base model: {base_model}")
        base_model_path = snapshot_download(repo_id=base_model)

        # Transformer (with VAE mask channels)
        print("[VOID] Loading transformer...")
        transformer = CogVideoXTransformer3DModel.from_pretrained(
            base_model_path,
            subfolder="transformer",
            low_cpu_mem_usage=True,
            torch_dtype=torch_dtype,
            use_vae_mask=True,
        )

        # Merge VOID weights (handles the patch-embed channel mismatch)
        print(f"[VOID] Merging Pass {void_pass} weights...")
        transformer = _merge_void_weights(transformer, checkpoint_path)
        transformer = transformer.to(torch_dtype)

        # VAE
        print("[VOID] Loading VAE...")
        vae = AutoencoderKLCogVideoX.from_pretrained(
            base_model_path, subfolder="vae"
        ).to(torch_dtype)

        # Tokenizer + Text encoder (imported lazily to keep module import light)
        print("[VOID] Loading tokenizer and text encoder...")
        from transformers import T5Tokenizer, T5EncoderModel
        tokenizer = T5Tokenizer.from_pretrained(base_model_path, subfolder="tokenizer")
        text_encoder = T5EncoderModel.from_pretrained(
            base_model_path, subfolder="text_encoder", torch_dtype=torch_dtype,
        )

        # Scheduler
        scheduler = CogVideoXDDIMScheduler.from_pretrained(
            base_model_path, subfolder="scheduler"
        )

        # Build pipeline via the parent class constructor
        pipe = cls(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Remember which pass this pipeline was built for.
        pipe._void_pass = void_pass

        print("[VOID] Pipeline ready!")
        return pipe

    def inpaint(
        self,
        video_path: str,
        mask_path: str,
        prompt: str,
        negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
        height: int = 384,
        width: int = 672,
        num_inference_steps: int = 30,
        guidance_scale: float = 1.0,
        strength: float = 1.0,
        temporal_window_size: int = 85,
        max_video_length: int = 197,
        fps: int = 12,
        seed: int = 42,
        pass1_video: Optional[str] = None,
        warped_noise_path: Optional[str] = None,
        use_quadmask: bool = True,
    ) -> VOIDOutput:
        """
        Run VOID inpainting on a video.

        Args:
            video_path: Path to input video (mp4).
            mask_path: Path to quadmask video (mp4). Grayscale with values:
                0=object to remove, 63=overlap, 127=affected region, 255=background.
            prompt: Text description of the desired result after removal.
                E.g., "A lime falls on the table."
            negative_prompt: Negative prompt for generation quality.
            height: Output height (default 384).
            width: Output width (default 672).
            num_inference_steps: Denoising steps (default 30).
            guidance_scale: CFG scale (default 1.0 = no CFG).
            strength: Denoising strength (default 1.0).
            temporal_window_size: Frames per inference window (default 85).
            max_video_length: Max frames to process (default 197).
            fps: Output FPS (default 12). Note: currently unused in this
                method; pass fps to VOIDOutput.save() instead.
            seed: Random seed (default 42).
            pass1_video: Path to Pass 1 output video, for Pass 2 warped noise init.
            warped_noise_path: Path to pre-computed warped noise (.npy).
            use_quadmask: Use 4-value quadmask (default True). Set False for trimask.

        Returns:
            VOIDOutput with .video (uint8) and .save() method.
        """
        sample_size = (height, width)

        # Align video length to VAE temporal compression ratio (length ≡ 1 mod ratio)
        vae_temporal_ratio = self.vae.config.temporal_compression_ratio
        video_length = int((max_video_length - 1) // vae_temporal_ratio * vae_temporal_ratio) + 1

        # --- Load and prep video ---
        print("[VOID] Loading video and mask...")
        vid_np = _load_video(video_path, video_length)
        mask_np = _load_video(mask_path, video_length)

        video = _prep_video_tensor(vid_np, sample_size)
        mask = _prep_mask_tensor(mask_np, sample_size, use_quadmask=use_quadmask)

        # Temporal padding (mirror-pad/truncate both tensors identically)
        video = _temporal_padding(video, min_length=temporal_window_size, max_length=max_video_length)
        mask = _temporal_padding(mask, min_length=temporal_window_size, max_length=max_video_length)

        num_frames = min(video.shape[2], temporal_window_size)

        print(f"[VOID] Video: {video.shape}, Mask: {mask.shape}, Frames: {num_frames}")

        # --- Handle warped noise for Pass 2 ---
        latents = None
        if warped_noise_path is not None or pass1_video is not None:
            # Latent grid: 4x temporal and 8x spatial compression.
            # NOTE(review): the //4 and //8 factors are hard-coded here rather
            # than read from the VAE config — confirm they match the base model.
            latent_T = (num_frames - 1) // 4 + 1
            latent_H = height // 8
            latent_W = width // 8
            latent_C = 16
            target_shape = (latent_T, latent_H, latent_W, latent_C)

            if warped_noise_path is not None:
                print(f"[VOID] Loading pre-computed warped noise from {warped_noise_path}")
                latents = _load_warped_noise(
                    warped_noise_path, target_shape,
                    device=torch.device("cpu"), dtype=torch.bfloat16,
                )
            else:
                print(f"[VOID] Generating warped noise from Pass 1 output...")
                latents = _generate_warped_noise(
                    pass1_video, target_shape,
                    device=torch.device("cpu"), dtype=torch.bfloat16,
                )
            print(f"[VOID] Warped noise: {latents.shape}, mean={latents.mean():.4f}, std={latents.std():.4f}")

        # --- Run inference ---
        # CPU generator keeps seeding reproducible across devices.
        generator = torch.Generator(device="cpu").manual_seed(seed)

        print(f"[VOID] Running inference ({num_frames} frames, {num_inference_steps} steps)...")
        with torch.no_grad():
            # NOTE(review): use_trimask=True is passed unconditionally, even
            # when use_quadmask=True — the quad/tri distinction is applied
            # earlier in _prep_mask_tensor. Confirm this is intended.
            output = self(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_frames=num_frames,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                video=video,
                mask_video=mask,
                strength=strength,
                use_trimask=True,
                use_vae_mask=True,
                latents=latents,
            ).videos

        # --- Process output ---
        if isinstance(output, np.ndarray):
            output = torch.from_numpy(output)

        # output is (B, C, T, H, W) in [0, 1]
        video_float = output
        # First batch item -> (T, H, W, C) uint8 frames.
        video_uint8 = (output[0].permute(1, 2, 3, 0).clamp(0, 1) * 255).to(torch.uint8)

        print(f"[VOID] Done! Output: {video_uint8.shape}")
        return VOIDOutput(video=video_uint8, video_float=video_float)