sam-motamed committed on
Commit
f3b6d71
·
verified ·
1 Parent(s): 42ada15

Move cogvideox_vae.py to diffusers/

Browse files
Files changed (1) hide show
  1. cogvideox_vae.py +0 -1675
cogvideox_vae.py DELETED
@@ -1,1675 +0,0 @@
1
- # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Dict, Optional, Tuple, Union
17
-
18
- import numpy as np
19
- import torch
20
- import torch.nn as nn
21
- import torch.nn.functional as F
22
- import json
23
- import os
24
-
25
- from diffusers.configuration_utils import ConfigMixin, register_to_config
26
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
27
- from diffusers.utils import logging
28
- from diffusers.utils.accelerate_utils import apply_forward_hook
29
- from diffusers.models.activations import get_activation
30
- from diffusers.models.downsampling import CogVideoXDownsample3D
31
- from diffusers.models.modeling_outputs import AutoencoderKLOutput
32
- from diffusers.models.modeling_utils import ModelMixin
33
- from diffusers.models.upsampling import CogVideoXUpsample3D
34
- from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
35
-
36
-
37
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
-
39
-
40
- class CogVideoXSafeConv3d(nn.Conv3d):
41
- r"""
42
- A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
43
- """
44
-
45
- def forward(self, input: torch.Tensor) -> torch.Tensor:
46
- memory_count = (
47
- (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
48
- )
49
-
50
- # Set to 2GB, suitable for CuDNN
51
- if memory_count > 2:
52
- kernel_size = self.kernel_size[0]
53
- part_num = int(memory_count / 2) + 1
54
- input_chunks = torch.chunk(input, part_num, dim=2)
55
-
56
- if kernel_size > 1:
57
- input_chunks = [input_chunks[0]] + [
58
- torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
59
- for i in range(1, len(input_chunks))
60
- ]
61
-
62
- output_chunks = []
63
- for input_chunk in input_chunks:
64
- output_chunks.append(super().forward(input_chunk))
65
- output = torch.cat(output_chunks, dim=2)
66
- return output
67
- else:
68
- return super().forward(input)
69
-
70
-
71
class CogVideoXCausalConv3d(nn.Module):
    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.

    Args:
        in_channels (`int`): Number of channels in the input tensor.
        out_channels (`int`): Number of output channels produced by the convolution.
        kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
        stride (`int`, defaults to `1`): Stride of the convolution (applied along the temporal axis only).
        dilation (`int`, defaults to `1`): Dilation rate of the convolution (temporal axis only).
        pad_mode (`str`, defaults to `"constant"`): Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: int = 1,
        dilation: int = 1,
        pad_mode: str = "constant",
    ):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        # TODO(aryan): configure calculation based on stride and dilation in the future.
        # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
        time_pad = time_kernel_size - 1
        height_pad = (height_kernel_size - 1) // 2
        width_pad = (width_kernel_size - 1) // 2

        self.pad_mode = pad_mode
        self.height_pad = height_pad
        self.width_pad = width_pad
        self.time_pad = time_pad
        # F.pad ordering: (w_left, w_right, h_top, h_bottom, t_front, t_back).
        # All temporal padding goes in front (t_back = 0) so the convolution is causal.
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        self.temporal_dim = 2
        self.time_kernel_size = time_kernel_size

        # Stride/dilation are expanded so they act only along the temporal axis.
        stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = CogVideoXSafeConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def fake_context_parallel_forward(
        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Prepend temporal context (cached frames, or the replicated first frame) to keep the conv causal."""
        if self.pad_mode == "replicate":
            inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
        else:
            kernel_size = self.time_kernel_size
            if kernel_size > 1:
                # Use frames cached from the previous chunk when available; otherwise
                # repeat the first frame to fill the (kernel_size - 1) causal context window.
                cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
                inputs = torch.cat(cached_inputs + [inputs], dim=2)
        return inputs

    def forward(
        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the causal 3D convolution.

        Returns:
            A `(output, conv_cache)` tuple; `conv_cache` holds the trailing frames to feed the next
            chunk during chunked (framewise) inference, or `None` in `"replicate"` mode.
        """
        inputs = self.fake_context_parallel_forward(inputs, conv_cache)

        if self.pad_mode == "replicate":
            # Replicate padding already covered every dimension; nothing to cache.
            conv_cache = None
        else:
            padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
            # Cache the last (time_kernel_size - 1) frames as context for the next chunk.
            # NOTE(review): for time_kernel_size == 1 this slice is `[:, :, 0:]`, i.e. the whole
            # tensor is cloned — presumably unintended but harmless; confirm before changing.
            conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
            inputs = F.pad(inputs, padding_2d, mode="constant", value=0)

        output = self.conv(inputs)
        return output, conv_cache
148
-
149
-
150
class CogVideoXSpatialNorm3D(nn.Module):
    r"""
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
    to 3D-video like data.

    CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.

    Args:
        f_channels (`int`):
            The number of channels for input to group normalization layer, and output of the spatial norm layer.
        zq_channels (`int`):
            The number of channels for the quantized vector as described in the paper.
        groups (`int`):
            Number of groups to separate the channels into for group normalization.
    """

    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
        groups: int = 32,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
        # 1x1x1 causal convolutions project the conditioning `zq` to a per-pixel scale and shift.
        self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
        self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)

    def forward(
        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Group-normalize `f` and modulate it with scale/shift derived from `zq`.

        Returns:
            A `(new_f, new_conv_cache)` tuple; the cache carries per-conv state for chunked inference.
        """
        new_conv_cache = {}
        conv_cache = conv_cache or {}

        if f.shape[2] > 1 and f.shape[2] % 2 == 1:
            # Odd frame count: the first frame is treated separately (causal "first frame"),
            # so `zq` is interpolated for it independently of the remaining frames.
            f_first, f_rest = f[:, :, :1], f[:, :, 1:]
            f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
            z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
            z_first = F.interpolate(z_first, size=f_first_size)
            z_rest = F.interpolate(z_rest, size=f_rest_size)
            zq = torch.cat([z_first, z_rest], dim=2)
        else:
            zq = F.interpolate(zq, size=f.shape[-3:])

        conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
        conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))

        norm_f = self.norm_layer(f)
        new_f = norm_f * conv_y + conv_b
        return new_f, new_conv_cache
199
-
200
-
201
# NOTE(review): this local definition shadows the `CogVideoXUpsample3D` imported from
# `diffusers.models.upsampling` at the top of the file; this variant adds the
# `auto_split_process` / `first_frame_flag` switches. Confirm the shadowing is intentional.
class CogVideoXUpsample3D(nn.Module):
    r"""
    A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper release.

    Args:
        in_channels (`int`):
            Number of channels in the input image.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `1`):
            Stride of the convolution.
        padding (`int`, defaults to `1`):
            Padding added to all four sides of the input.
        compress_time (`bool`, defaults to `False`):
            Whether or not to compress the time dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        compress_time: bool = False,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.compress_time = compress_time

        # auto_split_process: pick the first-frame split automatically from the input shape.
        # first_frame_flag: when auto-split is disabled, the caller marks the first-frame chunk manually.
        self.auto_split_process = True
        self.first_frame_flag = False

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Upsample spatially by 2x (and temporally by 2x when `compress_time` is set)."""
        if self.compress_time:
            if self.auto_split_process:
                if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
                    # Odd frame count: split off the first frame so it is only upsampled
                    # spatially, keeping the causal "first frame" untouched in time.
                    x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]

                    x_first = F.interpolate(x_first, scale_factor=2.0)
                    x_rest = F.interpolate(x_rest, scale_factor=2.0)
                    x_first = x_first[:, :, None, :, :]
                    inputs = torch.cat([x_first, x_rest], dim=2)
                elif inputs.shape[2] > 1:
                    # Even frame count: 5D interpolate doubles time, height and width together.
                    inputs = F.interpolate(inputs, scale_factor=2.0)
                else:
                    # Single frame: drop the time axis, upsample 2D, restore the axis.
                    inputs = inputs.squeeze(2)
                    inputs = F.interpolate(inputs, scale_factor=2.0)
                    inputs = inputs[:, :, None, :, :]
            else:
                if self.first_frame_flag:
                    # Caller-marked first-frame chunk: spatial upsampling only.
                    inputs = inputs.squeeze(2)
                    inputs = F.interpolate(inputs, scale_factor=2.0)
                    inputs = inputs[:, :, None, :, :]
                else:
                    inputs = F.interpolate(inputs, scale_factor=2.0)
        else:
            # only interpolate 2D: fold time into the batch so interpolation is purely spatial.
            b, c, t, h, w = inputs.shape
            inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
            inputs = F.interpolate(inputs, scale_factor=2.0)
            inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)

        # Apply the 2D convolution frame-by-frame (time folded into the batch).
        b, c, t, h, w = inputs.shape
        inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        inputs = self.conv(inputs)
        inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)

        return inputs
274
-
275
-
276
class CogVideoXResnetBlock3D(nn.Module):
    r"""
    A 3D ResNet block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        non_linearity (`str`, defaults to `"swish"`):
            Activation function to use.
        conv_shortcut (bool, defaults to `False`):
            Whether or not to use a convolution shortcut.
        spatial_norm_dim (`int`, *optional*):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        conv_shortcut: bool = False,
        spatial_norm_dim: Optional[int] = None,
        pad_mode: str = "first",
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nonlinearity = get_activation(non_linearity)
        self.use_conv_shortcut = conv_shortcut
        self.spatial_norm_dim = spatial_norm_dim

        # Decoder blocks condition the norm on the latent `zq` (spatial norm);
        # encoder blocks use plain group norm.
        if spatial_norm_dim is None:
            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
            self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
        else:
            self.norm1 = CogVideoXSpatialNorm3D(
                f_channels=in_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )
            self.norm2 = CogVideoXSpatialNorm3D(
                f_channels=out_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )

        self.conv1 = CogVideoXCausalConv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        if temb_channels > 0:
            self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)

        self.dropout = nn.Dropout(dropout)
        self.conv2 = CogVideoXCausalConv3d(
            in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        # Shortcut projection is only needed when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = CogVideoXCausalConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
                )
            else:
                self.conv_shortcut = CogVideoXSafeConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(
        self,
        inputs: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run norm -> act -> conv twice with optional temb/zq conditioning, plus a residual shortcut.

        Returns:
            A `(hidden_states, new_conv_cache)` tuple; the cache holds per-submodule state for
            chunked (framewise) inference.
        """
        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states = inputs

        if zq is not None:
            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))

        if temb is not None:
            # Broadcast the projected time embedding over the (T, H, W) axes.
            hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]

        if zq is not None:
            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
        else:
            hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
                    inputs, conv_cache=conv_cache.get("conv_shortcut")
                )
            else:
                inputs = self.conv_shortcut(inputs)

        hidden_states = hidden_states + inputs
        return hidden_states, new_conv_cache
405
-
406
-
407
class CogVideoXDownBlock3D(nn.Module):
    r"""
    A downsampling block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        add_downsample (`bool`, defaults to `True`):
            Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
        compress_time (`bool`, defaults to `False`):
            Whether or not to downsample across temporal dimension.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_padding: int = 0,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        resnets = []
        for i in range(num_layers):
            # Only the first resnet changes the channel count.
            in_channel = in_channels if i == 0 else out_channels
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.downsamplers = None

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    CogVideoXDownsample3D(
                        out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""Forward method of the `CogVideoXDownBlock3D` class.

        Returns a `(hidden_states, new_conv_cache)` tuple; the cache holds per-resnet state for
        chunked (framewise) inference.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Wrapper so torch.utils.checkpoint can replay the resnet with positional args.
                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    zq,
                    conv_cache.get(conv_cache_key),
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states, new_conv_cache
525
-
526
-
527
class CogVideoXMidBlock3D(nn.Module):
    r"""
    A middle block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        spatial_norm_dim (`int`, *optional*):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: Optional[int] = None,
        pad_mode: str = "first",
    ):
        super().__init__()

        # The mid block keeps the channel count fixed; it is a stack of identity-shaped resnets.
        resnets = []
        for _ in range(num_layers):
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    spatial_norm_dim=spatial_norm_dim,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
            )
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""Forward method of the `CogVideoXMidBlock3D` class.

        Returns a `(hidden_states, new_conv_cache)` tuple; the cache holds per-resnet state for
        chunked (framewise) inference.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Wrapper so torch.utils.checkpoint can replay the resnet with positional args.
                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        return hidden_states, new_conv_cache
619
-
620
-
621
class CogVideoXUpBlock3D(nn.Module):
    r"""
    An upsampling block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        spatial_norm_dim (`int`, defaults to `16`):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        add_upsample (`bool`, defaults to `True`):
            Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
        compress_time (`bool`, defaults to `False`):
            Whether or not to upsample across the temporal dimension.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: int = 16,
        add_upsample: bool = True,
        upsample_padding: int = 1,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        resnets = []
        for i in range(num_layers):
            # Only the first resnet changes the channel count.
            in_channel = in_channels if i == 0 else out_channels
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    spatial_norm_dim=spatial_norm_dim,
                    pad_mode=pad_mode,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.upsamplers = None

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    CogVideoXUpsample3D(
                        out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""Forward method of the `CogVideoXUpBlock3D` class.

        Returns a `(hidden_states, new_conv_cache)` tuple; the cache holds per-resnet state for
        chunked (framewise) inference.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Wrapper so torch.utils.checkpoint can replay the resnet with positional args.
                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    zq,
                    conv_cache.get(conv_cache_key),
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states, new_conv_cache
741
-
742
-
743
class CogVideoXEncoder3D(nn.Module):
    r"""
    The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 16):
            The number of latent channels; `conv_out` produces `2 * out_channels` (mean and logvar).
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXDownBlock3D"`):
            The types of down blocks to use. Only `"CogVideoXDownBlock3D"` is supported.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 3):
            The number of resnet layers per down block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon for group normalization.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate.
        pad_mode (`str`, *optional*, defaults to `"first"`):
            Padding mode for the causal convolutions.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            Total temporal downsampling factor; the first `log2(ratio)` down blocks compress time.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        down_block_types: Tuple[str, ...] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # log2 of temporal_compress_times
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
        self.down_blocks = nn.ModuleList([])

        # down blocks
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # Only the first `temporal_compress_level` blocks also downsample time.
            compress_time = i < temporal_compress_level

            if down_block_type == "CogVideoXDownBlock3D":
                down_block = CogVideoXDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    add_downsample=not is_final_block,
                    compress_time=compress_time,
                )
            else:
                raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")

            self.down_blocks.append(down_block)

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=block_out_channels[-1],
            temb_channels=0,
            dropout=dropout,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            pad_mode=pad_mode,
        )

        self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
        self.conv_act = nn.SiLU()
        # 2 * out_channels: the output is split into mean and logvar of the latent distribution.
        self.conv_out = CogVideoXCausalConv3d(
            block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""The forward method of the `CogVideoXEncoder3D` class.

        Returns a `(hidden_states, new_conv_cache)` tuple; the cache holds per-submodule state for
        chunked (framewise) encoding.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Wrapper so torch.utils.checkpoint can replay each block with positional args.
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Down
            for i, down_block in enumerate(self.down_blocks):
                conv_cache_key = f"down_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block),
                    hidden_states,
                    temb,
                    None,
                    conv_cache.get(conv_cache_key),
                )

            # 2. Mid
            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                temb,
                None,
                conv_cache.get("mid_block"),
            )
        else:
            # 1. Down
            for i, down_block in enumerate(self.down_blocks):
                conv_cache_key = f"down_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = down_block(
                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
                )

            # 2. Mid
            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
            )

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)

        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))

        return hidden_states, new_conv_cache
900
-
901
-
902
class CogVideoXDecoder3D(nn.Module):
    r"""
    The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # The decoder mirrors the encoder, so the channel schedule runs in reverse.
        reversed_block_out_channels = list(reversed(block_out_channels))

        self.conv_in = CogVideoXCausalConv3d(
            in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
        )

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=reversed_block_out_channels[0],
            temb_channels=0,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            spatial_norm_dim=in_channels,
            pad_mode=pad_mode,
        )

        # up blocks
        self.up_blocks = nn.ModuleList([])

        output_channel = reversed_block_out_channels[0]
        # Only the first log2(temporal_compression_ratio) up blocks upsample in time.
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            compress_time = i < temporal_compress_level

            if up_block_type == "CogVideoXUpBlock3D":
                up_block = CogVideoXUpBlock3D(
                    in_channels=prev_output_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    spatial_norm_dim=in_channels,
                    add_upsample=not is_final_block,
                    compress_time=compress_time,
                    pad_mode=pad_mode,
                )
                prev_output_channel = output_channel
            else:
                raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")

            self.up_blocks.append(up_block)

        self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        r"""The forward method of the `CogVideoXDecoder3D` class.

        Returns a `(hidden_states, new_conv_cache)` tuple. `conv_cache` carries the
        trailing activations of every causal conv from the previous temporal chunk so a
        long video can be decoded in slices without seams.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Mid
            # NOTE: the latent `sample` is forwarded as the third argument because the
            # resnet blocks use it for spatial (zq) conditioning.
            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                temb,
                sample,
                conv_cache.get("mid_block"),
            )

            # 2. Up
            for i, up_block in enumerate(self.up_blocks):
                conv_cache_key = f"up_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block),
                    hidden_states,
                    temb,
                    sample,
                    conv_cache.get(conv_cache_key),
                )
        else:
            # 1. Mid
            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
            )

            # 2. Up
            for i, up_block in enumerate(self.up_blocks):
                conv_cache_key = f"up_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = up_block(
                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
                )

        # 3. Post-process: spatial norm is conditioned on the latent `sample` as well.
        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
        )
        hidden_states = self.conv_act(hidden_states)
        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))

        return hidden_states, new_conv_cache
1067
-
1068
-
1069
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
    [CogVideoX](https://github.com/THUDM/CogVideo).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        scaling_factor (`float`, *optional*, defaults to `1.15258426`):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        force_upcast (`bool`, *optional*, default to `True`):
            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
            can be fine-tuned / trained to a lower range without loosing too much precision in which case
            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["CogVideoXResnetBlock3D"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        up_block_types: Tuple[str] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int] = (128, 256, 256, 512),
        latent_channels: int = 16,
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        temporal_compression_ratio: float = 4,
        sample_height: int = 480,
        sample_width: int = 720,
        scaling_factor: float = 1.15258426,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: float = True,
        use_quant_conv: bool = False,
        use_post_quant_conv: bool = False,
        invert_scale_latents: bool = False,
    ):
        super().__init__()

        self.encoder = CogVideoXEncoder3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_eps=norm_eps,
            norm_num_groups=norm_num_groups,
            temporal_compression_ratio=temporal_compression_ratio,
        )
        self.decoder = CogVideoXDecoder3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_eps=norm_eps,
            norm_num_groups=norm_num_groups,
            temporal_compression_ratio=temporal_compression_ratio,
        )
        # Optional 1x1x1 convs around the latent bottleneck; disabled by default in CogVideoX.
        self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
        self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None

        # Runtime flags toggled via the enable_*/disable_* methods below.
        self.use_slicing = False
        self.use_tiling = False
        self.auto_split_process = False

        # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
        # recommended because the temporal parts of the VAE, here, are tricky to understand.
        # If you decode X latent frames together, the number of output frames is:
        # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
        #
        # Example with num_latent_frames_batch_size = 2:
        # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
        #     => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
        #     => 6 * 8 = 48 frames
        # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
        #     => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
        #     ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
        #     => 1 * 9 + 5 * 8 = 49 frames
        # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
        # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
        # number of temporal frames.
        self.num_latent_frames_batch_size = 2
        self.num_sample_frames_batch_size = 8

        # We make the minimum height and width of sample for tiling half that of the generally supported
        self.tile_sample_min_height = sample_height // 2
        self.tile_sample_min_width = sample_width // 2
        # Latent-space tile sizes: sample sizes divided by the total spatial downscale factor.
        self.tile_latent_min_height = int(
            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
        )
        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))

        # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
        # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
        # and so the tiling implementation has only been tested on those specific resolutions.
        self.tile_overlap_factor_height = 1 / 6
        self.tile_overlap_factor_width = 1 / 5
1202
-
1203
- def _set_gradient_checkpointing(self, module, value=False):
1204
- if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
1205
- module.gradient_checkpointing = value
1206
-
1207
- def enable_tiling(
1208
- self,
1209
- tile_sample_min_height: Optional[int] = None,
1210
- tile_sample_min_width: Optional[int] = None,
1211
- tile_overlap_factor_height: Optional[float] = None,
1212
- tile_overlap_factor_width: Optional[float] = None,
1213
- ) -> None:
1214
- r"""
1215
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1216
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
1217
- processing larger images.
1218
-
1219
- Args:
1220
- tile_sample_min_height (`int`, *optional*):
1221
- The minimum height required for a sample to be separated into tiles across the height dimension.
1222
- tile_sample_min_width (`int`, *optional*):
1223
- The minimum width required for a sample to be separated into tiles across the width dimension.
1224
- tile_overlap_factor_height (`int`, *optional*):
1225
- The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
1226
- no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
1227
- value might cause more tiles to be processed leading to slow down of the decoding process.
1228
- tile_overlap_factor_width (`int`, *optional*):
1229
- The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
1230
- are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
1231
- value might cause more tiles to be processed leading to slow down of the decoding process.
1232
- """
1233
- self.use_tiling = True
1234
- self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
1235
- self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
1236
- self.tile_latent_min_height = int(
1237
- self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1238
- )
1239
- self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1240
- self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
1241
- self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
1242
-
1243
- def disable_tiling(self) -> None:
1244
- r"""
1245
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1246
- decoding in one step.
1247
- """
1248
- self.use_tiling = False
1249
-
1250
- def enable_slicing(self) -> None:
1251
- r"""
1252
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1253
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1254
- """
1255
- self.use_slicing = True
1256
-
1257
- def disable_slicing(self) -> None:
1258
- r"""
1259
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1260
- decoding in one step.
1261
- """
1262
- self.use_slicing = False
1263
-
1264
- def _set_first_frame(self):
1265
- for name, module in self.named_modules():
1266
- if isinstance(module, CogVideoXUpsample3D):
1267
- module.auto_split_process = False
1268
- module.first_frame_flag = True
1269
-
1270
- def _set_rest_frame(self):
1271
- for name, module in self.named_modules():
1272
- if isinstance(module, CogVideoXUpsample3D):
1273
- module.auto_split_process = False
1274
- module.first_frame_flag = False
1275
-
1276
- def enable_auto_split_process(self) -> None:
1277
- self.auto_split_process = True
1278
- for name, module in self.named_modules():
1279
- if isinstance(module, CogVideoXUpsample3D):
1280
- module.auto_split_process = True
1281
-
1282
- def disable_auto_split_process(self) -> None:
1283
- self.auto_split_process = False
1284
-
1285
    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a video tensor `x` of shape (batch, channels, frames, height, width)
        into concatenated mean/logvar latents, batching over the frame dimension and
        threading the causal-conv cache between temporal chunks."""
        batch_size, num_channels, num_frames, height, width = x.shape

        # Fall back to the tiled path when the frame is larger than one tile.
        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
            return self.tiled_encode(x)

        frame_batch_size = self.num_sample_frames_batch_size
        # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
        # As the extra single frame is handled inside the loop, it is not required to round up here.
        num_batches = max(num_frames // frame_batch_size, 1)
        conv_cache = None
        enc = []

        for i in range(num_batches):
            # The leftover `num_frames % frame_batch_size` frames are folded into the
            # first chunk; later chunks are shifted by that remainder.
            remaining_frames = num_frames % frame_batch_size
            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
            end_frame = frame_batch_size * (i + 1) + remaining_frames
            x_intermediate = x[:, :, start_frame:end_frame]
            # The conv cache from the previous chunk keeps the causal convs seamless.
            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
            if self.quant_conv is not None:
                x_intermediate = self.quant_conv(x_intermediate)
            enc.append(x_intermediate)

        # Reassemble along the temporal (frame) dimension.
        enc = torch.cat(enc, dim=2)
        return enc
1310
-
1311
- @apply_forward_hook
1312
- def encode(
1313
- self, x: torch.Tensor, return_dict: bool = True
1314
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1315
- """
1316
- Encode a batch of images into latents.
1317
-
1318
- Args:
1319
- x (`torch.Tensor`): Input batch of images.
1320
- return_dict (`bool`, *optional*, defaults to `True`):
1321
- Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1322
-
1323
- Returns:
1324
- The latent representations of the encoded videos. If `return_dict` is True, a
1325
- [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
1326
- """
1327
- if self.use_slicing and x.shape[0] > 1:
1328
- encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
1329
- h = torch.cat(encoded_slices)
1330
- else:
1331
- h = self._encode(x)
1332
-
1333
- posterior = DiagonalGaussianDistribution(h)
1334
-
1335
- if not return_dict:
1336
- return (posterior,)
1337
- return AutoencoderKLOutput(latent_dist=posterior)
1338
-
1339
    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        """Decode latents `z` of shape (batch, channels, frames, height, width) into
        video frames, batching over the temporal dimension.

        Two temporal-batching strategies exist:
        - `auto_split_process`: uniform chunks handled by the upsamplers themselves.
        - manual path: the first latent frame is decoded alone (with
          `_set_first_frame`), then the rest in `num_latent_frames_batch_size` chunks.
        """
        batch_size, num_channels, num_frames, height, width = z.shape

        # Large frames go through the tiled decoder instead.
        if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
            return self.tiled_decode(z, return_dict=return_dict)

        if self.auto_split_process:
            frame_batch_size = self.num_latent_frames_batch_size
            num_batches = max(num_frames // frame_batch_size, 1)
            conv_cache = None
            dec = []

            for i in range(num_batches):
                # Leftover frames are folded into the first chunk; later chunks shift.
                remaining_frames = num_frames % frame_batch_size
                start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
                end_frame = frame_batch_size * (i + 1) + remaining_frames
                z_intermediate = z[:, :, start_frame:end_frame]
                if self.post_quant_conv is not None:
                    z_intermediate = self.post_quant_conv(z_intermediate)
                # conv_cache threads the causal-conv state across chunks.
                z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
                dec.append(z_intermediate)
        else:
            conv_cache = None
            start_frame = 0
            end_frame = 1
            dec = []

            # First latent frame is decoded on its own with the upsamplers in
            # "first frame" mode (different temporal upsampling behaviour).
            self._set_first_frame()
            z_intermediate = z[:, :, start_frame:end_frame]
            if self.post_quant_conv is not None:
                z_intermediate = self.post_quant_conv(z_intermediate)
            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
            dec.append(z_intermediate)

            # Remaining frames are decoded in fixed-size chunks in "rest frame" mode.
            self._set_rest_frame()
            start_frame = end_frame
            end_frame += self.num_latent_frames_batch_size

            while start_frame < num_frames:
                z_intermediate = z[:, :, start_frame:end_frame]
                if self.post_quant_conv is not None:
                    z_intermediate = self.post_quant_conv(z_intermediate)
                z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
                dec.append(z_intermediate)
                start_frame = end_frame
                end_frame += self.num_latent_frames_batch_size

        # Reassemble the decoded chunks along the temporal dimension.
        dec = torch.cat(dec, dim=2)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
1392
-
1393
- @apply_forward_hook
1394
- def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1395
- """
1396
- Decode a batch of images.
1397
-
1398
- Args:
1399
- z (`torch.Tensor`): Input batch of latent vectors.
1400
- return_dict (`bool`, *optional*, defaults to `True`):
1401
- Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1402
-
1403
- Returns:
1404
- [`~models.vae.DecoderOutput`] or `tuple`:
1405
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1406
- returned.
1407
- """
1408
- if self.use_slicing and z.shape[0] > 1:
1409
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1410
- decoded = torch.cat(decoded_slices)
1411
- else:
1412
- decoded = self._decode(z).sample
1413
-
1414
- if not return_dict:
1415
- return (decoded,)
1416
- return DecoderOutput(sample=decoded)
1417
-
1418
- def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1419
- blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1420
- for y in range(blend_extent):
1421
- b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1422
- y / blend_extent
1423
- )
1424
- return b
1425
-
1426
- def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1427
- blend_extent = min(a.shape[4], b.shape[4], blend_extent)
1428
- for x in range(blend_extent):
1429
- b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1430
- x / blend_extent
1431
- )
1432
- return b
1433
-
1434
    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.Tensor`): Input batch of videos.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        """
        # For a rough memory estimate, take a look at the `tiled_decode` method.
        batch_size, num_channels, num_frames, height, width = x.shape

        # Stride between tile origins in pixel space (overlap_*), the number of latent
        # rows/cols cross-faded between neighbours (blend_extent_*), and the portion of
        # each latent tile that is kept after blending (row_limit_*).
        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
        row_limit_height = self.tile_latent_min_height - blend_extent_height
        row_limit_width = self.tile_latent_min_width - blend_extent_width
        frame_batch_size = self.num_sample_frames_batch_size

        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, overlap_height):
            row = []
            for j in range(0, width, overlap_width):
                # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
                # As the extra single frame is handled inside the loop, it is not required to round up here.
                num_batches = max(num_frames // frame_batch_size, 1)
                conv_cache = None
                time = []

                for k in range(num_batches):
                    # Same remainder-into-first-chunk scheme as `_encode`.
                    remaining_frames = num_frames % frame_batch_size
                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                    end_frame = frame_batch_size * (k + 1) + remaining_frames
                    tile = x[
                        :,
                        :,
                        start_frame:end_frame,
                        i : i + self.tile_sample_min_height,
                        j : j + self.tile_sample_min_width,
                    ]
                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
                    if self.quant_conv is not None:
                        tile = self.quant_conv(tile)
                    time.append(tile)

                row.append(torch.cat(time, dim=2))
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
            result_rows.append(torch.cat(result_row, dim=4))

        enc = torch.cat(result_rows, dim=3)
        return enc
1507
-
1508
    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        # Rough memory assessment:
        #   - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
        #   - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
        #   - Assume fp16 (2 bytes per value).
        # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
        #
        # Memory assessment when using tiling:
        #   - Assume everything as above but now HxW is 240x360 by tiling in half
        # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB

        batch_size, num_channels, num_frames, height, width = z.shape

        # Tile stride in latent space, cross-fade extents in pixel space, and the kept
        # portion of each decoded tile (mirror image of `tiled_encode`).
        overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
        overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
        blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
        blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
        row_limit_height = self.tile_sample_min_height - blend_extent_height
        row_limit_width = self.tile_sample_min_width - blend_extent_width
        frame_batch_size = self.num_latent_frames_batch_size

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, overlap_height):
            row = []
            for j in range(0, width, overlap_width):
                if self.auto_split_process:
                    # Uniform temporal chunks; the upsamplers split frames themselves.
                    num_batches = max(num_frames // frame_batch_size, 1)
                    conv_cache = None
                    time = []

                    for k in range(num_batches):
                        remaining_frames = num_frames % frame_batch_size
                        start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                        end_frame = frame_batch_size * (k + 1) + remaining_frames
                        tile = z[
                            :,
                            :,
                            start_frame:end_frame,
                            i : i + self.tile_latent_min_height,
                            j : j + self.tile_latent_min_width,
                        ]
                        if self.post_quant_conv is not None:
                            tile = self.post_quant_conv(tile)
                        tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                        time.append(tile)

                    row.append(torch.cat(time, dim=2))
                else:
                    # Manual path: first latent frame alone, then fixed-size chunks
                    # (same scheme as `_decode`).
                    conv_cache = None
                    start_frame = 0
                    end_frame = 1
                    dec = []

                    tile = z[
                        :,
                        :,
                        start_frame:end_frame,
                        i : i + self.tile_latent_min_height,
                        j : j + self.tile_latent_min_width,
                    ]

                    self._set_first_frame()
                    if self.post_quant_conv is not None:
                        tile = self.post_quant_conv(tile)
                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                    dec.append(tile)

                    self._set_rest_frame()
                    start_frame = end_frame
                    end_frame += self.num_latent_frames_batch_size

                    while start_frame < num_frames:
                        tile = z[
                            :,
                            :,
                            start_frame:end_frame,
                            i : i + self.tile_latent_min_height,
                            j : j + self.tile_latent_min_width,
                        ]
                        if self.post_quant_conv is not None:
                            tile = self.post_quant_conv(tile)
                        tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                        dec.append(tile)
                        start_frame = end_frame
                        end_frame += self.num_latent_frames_batch_size

                    row.append(torch.cat(dec, dim=2))
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
            result_rows.append(torch.cat(result_row, dim=4))

        dec = torch.cat(result_rows, dim=3)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
1631
-
1632
- def forward(
1633
- self,
1634
- sample: torch.Tensor,
1635
- sample_posterior: bool = False,
1636
- return_dict: bool = True,
1637
- generator: Optional[torch.Generator] = None,
1638
- ) -> Union[torch.Tensor, torch.Tensor]:
1639
- x = sample
1640
- posterior = self.encode(x).latent_dist
1641
- if sample_posterior:
1642
- z = posterior.sample(generator=generator)
1643
- else:
1644
- z = posterior.mode()
1645
- dec = self.decode(z)
1646
- if not return_dict:
1647
- return (dec,)
1648
- return dec
1649
-
1650
- @classmethod
1651
- def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
1652
- if subfolder is not None:
1653
- pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1654
-
1655
- config_file = os.path.join(pretrained_model_path, 'config.json')
1656
- if not os.path.isfile(config_file):
1657
- raise RuntimeError(f"{config_file} does not exist")
1658
- with open(config_file, "r") as f:
1659
- config = json.load(f)
1660
-
1661
- model = cls.from_config(config, **vae_additional_kwargs)
1662
- from diffusers.utils import WEIGHTS_NAME
1663
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1664
- model_file_safetensors = model_file.replace(".bin", ".safetensors")
1665
- if os.path.exists(model_file_safetensors):
1666
- from safetensors.torch import load_file, safe_open
1667
- state_dict = load_file(model_file_safetensors)
1668
- else:
1669
- if not os.path.isfile(model_file):
1670
- raise RuntimeError(f"{model_file} does not exist")
1671
- state_dict = torch.load(model_file, map_location="cpu")
1672
- m, u = model.load_state_dict(state_dict, strict=False)
1673
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1674
- print(m, u)
1675
- return model