sam-motamed commited on
Commit
bad41bb
·
verified ·
1 Parent(s): e3ea692

Upload 51 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. app.py +278 -4
  3. requirements.txt +61 -0
  4. sample/lime/first_frame.jpg +3 -0
  5. sample/lime/input_video.mp4 +3 -0
  6. sample/lime/prompt.json +3 -0
  7. sample/lime/quadmask_0.mp4 +3 -0
  8. sample/lime/segmentation_info.json +221 -0
  9. sample/moving_ball/first_frame.jpg +3 -0
  10. sample/moving_ball/input_video.mp4 +3 -0
  11. sample/moving_ball/prompt.json +3 -0
  12. sample/moving_ball/quadmask_0.mp4 +3 -0
  13. sample/pillow/input_video.mp4 +3 -0
  14. sample/pillow/prompt.json +3 -0
  15. sample/pillow/quadmask_0.mp4 +3 -0
  16. sample/pillow/segmentation_info.json +85 -0
  17. videox_fun/__init__.py +0 -0
  18. videox_fun/api/api.py +213 -0
  19. videox_fun/api/api_multi_nodes.py +215 -0
  20. videox_fun/data/bucket_sampler.py +390 -0
  21. videox_fun/data/dataset_image.py +76 -0
  22. videox_fun/data/dataset_image_video.py +1067 -0
  23. videox_fun/data/dataset_image_video_warped.py +1092 -0
  24. videox_fun/data/dataset_video.py +262 -0
  25. videox_fun/dist/__init__.py +40 -0
  26. videox_fun/dist/cogvideox_xfuser.py +116 -0
  27. videox_fun/dist/wan_xfuser.py +115 -0
  28. videox_fun/models/__init__.py +4 -0
  29. videox_fun/models/cache_utils.py +74 -0
  30. videox_fun/models/cogvideox_transformer3d.py +845 -0
  31. videox_fun/models/cogvideox_vae.py +1675 -0
  32. videox_fun/pipeline/__init__.py +2 -0
  33. videox_fun/pipeline/pipeline_cogvideox_fun.py +862 -0
  34. videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py +1244 -0
  35. videox_fun/pipeline/pipeline_wan_fun.py +558 -0
  36. videox_fun/reward/MPS/README.md +1 -0
  37. videox_fun/reward/MPS/trainer/models/base_model.py +7 -0
  38. videox_fun/reward/MPS/trainer/models/clip_model.py +154 -0
  39. videox_fun/reward/MPS/trainer/models/cross_modeling.py +291 -0
  40. videox_fun/reward/aesthetic_predictor_v2_5/__init__.py +13 -0
  41. videox_fun/reward/aesthetic_predictor_v2_5/siglip_v2_5.py +133 -0
  42. videox_fun/reward/improved_aesthetic_predictor.py +49 -0
  43. videox_fun/reward/reward_fn.py +385 -0
  44. videox_fun/ui/cogvideox_fun_ui.py +667 -0
  45. videox_fun/ui/ui.py +290 -0
  46. videox_fun/ui/wan_fun_ui.py +630 -0
  47. videox_fun/utils/__init__.py +0 -0
  48. videox_fun/utils/discrete_sampler.py +46 -0
  49. videox_fun/utils/fp8_optimization.py +56 -0
  50. videox_fun/utils/lora_utils.py +516 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ sample/lime/first_frame.jpg filter=lfs diff=lfs merge=lfs -text
37
+ sample/moving_ball/first_frame.jpg filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,7 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ """
2
+ VOID – Video Object and Interaction Deletion
3
+ Gradio demo for Hugging Face Spaces (ZeroGPU)
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import tempfile
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import imageio
14
+ import mediapy as media
15
+ import spaces
16
  import gradio as gr
17
+ from huggingface_hub import hf_hub_download
18
+ from safetensors.torch import load_file
19
+ from diffusers import DDIMScheduler
20
+ from PIL import Image
21
+
22
+ # ── project imports ────────────────────────────────────────────────────────────
23
+ sys.path.insert(0, os.path.dirname(__file__))
24
+
25
+ from videox_fun.models import (
26
+ AutoencoderKLCogVideoX,
27
+ CogVideoXTransformer3DModel,
28
+ T5EncoderModel,
29
+ T5Tokenizer,
30
+ )
31
+ from videox_fun.pipeline import CogVideoXFunInpaintPipeline
32
+ from videox_fun.utils.fp8_optimization import convert_weight_dtype_wrapper
33
+ from videox_fun.utils.utils import temporal_padding
34
+
35
+ # ── constants ──────────────────────────────────────────────────────────────────
36
+ # Set these env vars in your HF Space settings, or hardcode once weights are public.
37
+ BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "alibaba-pai/CogVideoX-Fun-V1.5-5b-InP")
38
+ VOID_MODEL_ID = os.environ.get("VOID_MODEL_ID", "your-hf-username/VOID")
39
+ VOID_CKPT_FILE = "void_pass1.safetensors"
40
+
41
+ SAMPLE_SIZE = (384, 672) # H × W
42
+ MAX_VID_LEN = 197
43
+ TEMPORAL_WIN = 85
44
+ FPS = 12
45
+ WEIGHT_DTYPE = torch.bfloat16
46
+ NEG_PROMPT = (
47
+ "The video is not of a high quality, it has a low resolution. "
48
+ "Watermark present in each frame. The background is solid. "
49
+ "Strange body and strange trajectory. Distortion."
50
+ )
51
+
52
+ # ── model loading (once at startup, lives in CPU RAM between GPU requests) ─────
53
+ print("Loading VOID pipeline …")
54
+
55
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
56
+ BASE_MODEL_ID,
57
+ subfolder="transformer",
58
+ low_cpu_mem_usage=True,
59
+ torch_dtype=torch.float8_e4m3fn, # qfloat8 to save VRAM
60
+ use_vae_mask=True,
61
+ stack_mask=False,
62
+ ).to(WEIGHT_DTYPE)
63
+
64
+ # Load VOID Pass-1 checkpoint
65
+ ckpt_path = hf_hub_download(repo_id=VOID_MODEL_ID, filename=VOID_CKPT_FILE)
66
+ state_dict = load_file(ckpt_path)
67
+ state_dict = state_dict.get("state_dict", state_dict)
68
+
69
+ # Adapt patch_embed channels if they differ (mask-conditioning channels added)
70
+ param_name = "patch_embed.proj.weight"
71
+ if state_dict[param_name].size(1) != transformer.state_dict()[param_name].size(1):
72
+ feat_dim = 16 * 8 # latent_channels * feat_scale
73
+ new_weight = transformer.state_dict()[param_name].clone()
74
+ new_weight[:, :feat_dim] = state_dict[param_name][:, :feat_dim]
75
+ new_weight[:, -feat_dim:] = state_dict[param_name][:, -feat_dim:]
76
+ state_dict[param_name] = new_weight
77
+
78
+ transformer.load_state_dict(state_dict, strict=False)
79
+
80
+ vae = AutoencoderKLCogVideoX.from_pretrained(
81
+ BASE_MODEL_ID, subfolder="vae"
82
+ ).to(WEIGHT_DTYPE)
83
+ tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL_ID, subfolder="tokenizer")
84
+ text_encoder = T5EncoderModel.from_pretrained(
85
+ BASE_MODEL_ID, subfolder="text_encoder", torch_dtype=WEIGHT_DTYPE
86
+ )
87
+ scheduler = DDIMScheduler.from_pretrained(BASE_MODEL_ID, subfolder="scheduler")
88
+
89
+ pipeline = CogVideoXFunInpaintPipeline(
90
+ vae=vae,
91
+ tokenizer=tokenizer,
92
+ text_encoder=text_encoder,
93
+ transformer=transformer,
94
+ scheduler=scheduler,
95
+ )
96
+ convert_weight_dtype_wrapper(transformer, WEIGHT_DTYPE)
97
+ pipeline.enable_model_cpu_offload()
98
+
99
+ print("VOID pipeline ready.")
100
+
101
+
102
+ # ── helpers ────────────────────────────────────────────────────────────────────
103
+ def load_video_tensor(path: str) -> torch.Tensor:
104
+ """Return (1, C, T, H, W) float32 in [0, 1] resized to SAMPLE_SIZE."""
105
+ frames = media.read_video(path)
106
+ t = torch.from_numpy(np.array(frames))[:MAX_VID_LEN] # (T, H, W, C)
107
+ t = t.permute(3, 0, 1, 2).float() / 255.0 # (C, T, H, W)
108
+ t = F.interpolate(t, SAMPLE_SIZE, mode="area").unsqueeze(0)
109
+ return t
110
+
111
+
112
+ def load_quadmask_tensor(path: str) -> torch.Tensor:
113
+ """
114
+ Return (1, 1, T, H, W) float32 in [0, 1].
115
+
116
+ Quadmask pixel values:
117
+ 0 → primary object (to erase)
118
+ 63 → overlap / interaction zone
119
+ 127 → affected region (shadows, reflections …)
120
+ 255 → background (keep)
121
+
122
+ After quantisation the mask is inverted so 255 = "erase", 0 = "keep",
123
+ matching the pipeline's internal convention.
124
+ """
125
+ frames = media.read_video(path)[:MAX_VID_LEN]
126
+ if frames.ndim == 4:
127
+ frames = frames[..., 0] # take first channel, grayscale
128
+ m = torch.from_numpy(np.array(frames)).unsqueeze(0).float() # (1, T, H, W)
129
+ m = F.interpolate(m, SAMPLE_SIZE, mode="area").unsqueeze(0) # (1, 1, T, H, W)
130
+
131
+ # Quantise to four canonical values
132
+ m = torch.where(m <= 31, torch.zeros_like(m), m)
133
+ m = torch.where((m > 31) & (m <= 95), torch.full_like(m, 63), m)
134
+ m = torch.where((m > 95) & (m <= 191), torch.full_like(m, 127), m)
135
+ m = torch.where(m > 191, torch.full_like(m, 255), m)
136
+
137
+ m = 255.0 - m # invert
138
+ return m / 255.0
139
+
140
+
141
+ def tensor_to_mp4(video: torch.Tensor) -> str:
142
+ """Save (1, C, T, H, W) in [0, 1] to a temp mp4 and return the path."""
143
+ frames = video[0].permute(1, 2, 3, 0).cpu().float().numpy() # (T, H, W, C)
144
+ frames = (frames * 255).clip(0, 255).astype(np.uint8)
145
+ tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
146
+ imageio.mimsave(tmp.name, frames, fps=FPS)
147
+ return tmp.name
148
+
149
+
150
+ # ── inference ──────────────────────────────────────────────────────────────────
151
+ @spaces.GPU(duration=300)
152
+ def run_inpaint(
153
+ input_video_path: str,
154
+ mask_video_path: str,
155
+ prompt: str,
156
+ num_steps: int,
157
+ guidance_scale: float,
158
+ seed: int,
159
+ ) -> str:
160
+ if not input_video_path or not mask_video_path:
161
+ raise gr.Error("Please upload both an input video and a quadmask video.")
162
+ if not prompt.strip():
163
+ raise gr.Error("Please enter a prompt describing the scene after removal.")
164
+
165
+ generator = torch.Generator(device="cuda").manual_seed(int(seed))
166
+
167
+ input_video = load_video_tensor(input_video_path)
168
+ input_mask = load_quadmask_tensor(mask_video_path)
169
+
170
+ input_video = temporal_padding(input_video, min_length=TEMPORAL_WIN, max_length=MAX_VID_LEN)
171
+ input_mask = temporal_padding(input_mask, min_length=TEMPORAL_WIN, max_length=MAX_VID_LEN)
172
+
173
+ with torch.no_grad():
174
+ result = pipeline(
175
+ prompt=prompt,
176
+ negative_prompt=NEG_PROMPT,
177
+ height=SAMPLE_SIZE[0],
178
+ width=SAMPLE_SIZE[1],
179
+ num_frames=TEMPORAL_WIN,
180
+ video=input_video,
181
+ mask_video=input_mask,
182
+ generator=generator,
183
+ guidance_scale=guidance_scale,
184
+ num_inference_steps=num_steps,
185
+ strength=1.0,
186
+ use_trimask=True,
187
+ use_vae_mask=True,
188
+ stack_mask=False,
189
+ zero_out_mask_region=False,
190
+ ).videos
191
+
192
+ return tensor_to_mp4(result)
193
+
194
+
195
+ # ── Gradio UI ──────────────────────────────────────────────────────────────────
196
+ QUADMASK_EXPLAINER = """
197
+ ### Quadmask format
198
+
199
+ The quadmask is a **grayscale video** where each pixel value encodes what role that region plays:
200
+
201
+ | Pixel value | Meaning |
202
+ |-------------|---------|
203
+ | **0** (black) | Primary object to remove |
204
+ | **63** (dark grey) | Overlap / interaction zone |
205
+ | **127** (mid grey) | Affected region — shadows, reflections, secondary effects |
206
+ | **255** (white) | Background — keep as-is |
207
+
208
+ Use the **VLM-Mask-Reasoner** pipeline included in the repo to generate quadmasks automatically.
209
+ """
210
+
211
+ SAMPLE_DIR = os.path.join(os.path.dirname(__file__), "sample")
212
+ EXAMPLES = [
213
+ [
214
+ os.path.join(SAMPLE_DIR, "lime", "input_video.mp4"),
215
+ os.path.join(SAMPLE_DIR, "lime", "quadmask_0.mp4"),
216
+ "A lime falls on the table.",
217
+ 30, 1.0, 42,
218
+ ],
219
+ [
220
+ os.path.join(SAMPLE_DIR, "moving_ball", "input_video.mp4"),
221
+ os.path.join(SAMPLE_DIR, "moving_ball", "quadmask_0.mp4"),
222
+ "A ball rolls off the table.",
223
+ 30, 1.0, 42,
224
+ ],
225
+ [
226
+ os.path.join(SAMPLE_DIR, "pillow", "input_video.mp4"),
227
+ os.path.join(SAMPLE_DIR, "pillow", "quadmask_0.mp4"),
228
+ "Two pillows placed on the table.",
229
+ 30, 1.0, 42,
230
+ ],
231
+ ]
232
+
233
+ with gr.Blocks(title="VOID – Video Object & Interaction Deletion") as demo:
234
+ gr.Markdown(
235
+ """
236
+ # VOID – Video Object and Interaction Deletion
237
+
238
+ Upload a video and its **quadmask**, enter a prompt describing the scene *after* removal,
239
+ and VOID will erase the object along with its physical interactions (shadows, deformations, secondary motion).
240
+
241
+ > Built on **CogVideoX-Fun-V1.5-5B** fine-tuned for interaction-aware video inpainting.
242
+ """
243
+ )
244
+
245
+ with gr.Row():
246
+ with gr.Column():
247
+ input_video = gr.Video(label="Input video", sources=["upload"])
248
+ mask_video = gr.Video(label="Quadmask video", sources=["upload"])
249
+ prompt = gr.Textbox(
250
+ label="Prompt — describe the scene after removal",
251
+ placeholder="e.g. A wooden table with nothing on it.",
252
+ lines=2,
253
+ )
254
+ with gr.Accordion("Advanced settings", open=False):
255
+ num_steps = gr.Slider(10, 50, value=30, step=1, label="Inference steps")
256
+ guidance_scale = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="Guidance scale")
257
+ seed = gr.Number(value=42, label="Seed", precision=0)
258
+ run_btn = gr.Button("Run VOID", variant="primary")
259
+
260
+ with gr.Column():
261
+ output_video = gr.Video(label="Inpainted output", interactive=False)
262
+
263
+ gr.Markdown(QUADMASK_EXPLAINER)
264
+
265
+ gr.Examples(
266
+ examples=EXAMPLES,
267
+ inputs=[input_video, mask_video, prompt, num_steps, guidance_scale, seed],
268
+ outputs=[output_video],
269
+ fn=run_inpaint,
270
+ cache_examples=True,
271
+ label="Sample sequences — click to load and run",
272
+ )
273
 
274
+ run_btn.click(
275
+ fn=run_inpaint,
276
+ inputs=[input_video, mask_video, prompt, num_steps, guidance_scale, seed],
277
+ outputs=[output_video],
278
+ )
279
 
280
+ if __name__ == "__main__":
281
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core deep learning
2
+ torch==2.7.1
3
+ torchvision==0.22.1
4
+ torchdiffeq==0.2.5
5
+ torchsde==0.2.6
6
+
7
+ # Diffusion / generation
8
+ diffusers==0.33.1
9
+ accelerate==1.12.0
10
+ transformers==4.57.1
11
+ safetensors==0.6.2
12
+ peft==0.17.1
13
+
14
+ # Training utilities
15
+ deepspeed==0.17.6
16
+ came-pytorch==0.1.3
17
+ tensorboard==2.20.0
18
+
19
+ # Vision / video
20
+ opencv-python==4.10.0.84
21
+ scikit-image==0.25.2
22
+ imageio==2.37.0
23
+ imageio-ffmpeg==0.6.0
24
+ mediapy==1.2.4
25
+ decord==0.6.0
26
+ kornia==0.8.1
27
+ albumentations==2.0.8
28
+ timm==1.0.19
29
+ tomesd==0.1.3
30
+ Pillow==11.3.0
31
+
32
+ # Data / ML utilities
33
+ numpy==1.26.4
34
+ scipy==1.14.0
35
+ scikit-learn==1.7.2
36
+ datasets==4.0.0
37
+ einops==0.8.0
38
+
39
+ # Config / logging
40
+ omegaconf==2.3.0
41
+ ml_collections==1.1.0
42
+ absl-py==2.3.1
43
+ loguru==0.7.3
44
+ tqdm==4.67.1
45
+ matplotlib==3.10.6
46
+
47
+ # NLP
48
+ sentencepiece==0.2.1
49
+ ftfy==6.1.1
50
+ beautifulsoup4==4.13.5
51
+
52
+ # Misc
53
+ func-timeout==4.3.5
54
+ requests==2.32.5
55
+ packaging==25.0
56
+
57
+ # Optional: Gradio UI (only needed for app.py / demo)
58
+ # gradio>=3.41.2,<=3.48.0
59
+
60
+ # Note: SAM2 must be installed separately per the instructions at
61
+ # https://github.com/facebookresearch/sam2?tab=readme-ov-file#installation
sample/lime/first_frame.jpg ADDED

Git LFS Details

  • SHA256: eb8e417430f5cc3ee15bffcd60c51b13461c567c3809dedeaeaca55efb567c06
  • Pointer size: 131 Bytes
  • Size of remote file: 894 kB
sample/lime/input_video.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0efabfbfc85bf29d11ac0f734eccf5dc824c511333c15953b73d3e357d7d9a87
3
+ size 3892459
sample/lime/prompt.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "bg": "A lime falls on the table."
3
+ }
sample/lime/quadmask_0.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00a01b7fb47107edcbfd5a036d6d7b1097ea8624df9c2440d184ddfa90a8bdd5
3
+ size 1907329
sample/lime/segmentation_info.json ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "total_frames": 46,
3
+ "frame_width": 3840,
4
+ "frame_height": 2160,
5
+ "fps": 12.0,
6
+ "num_points": 25,
7
+ "points_by_frame": {
8
+ "0": [
9
+ [
10
+ 2126,
11
+ 1099
12
+ ],
13
+ [
14
+ 2366,
15
+ 1099
16
+ ],
17
+ [
18
+ 2683,
19
+ 1080
20
+ ],
21
+ [
22
+ 2784,
23
+ 1176
24
+ ],
25
+ [
26
+ 2640,
27
+ 1176
28
+ ],
29
+ [
30
+ 2539,
31
+ 1176
32
+ ],
33
+ [
34
+ 2318,
35
+ 1176
36
+ ],
37
+ [
38
+ 2116,
39
+ 1291
40
+ ],
41
+ [
42
+ 2496,
43
+ 1291
44
+ ],
45
+ [
46
+ 2654,
47
+ 1286
48
+ ],
49
+ [
50
+ 2654,
51
+ 1406
52
+ ],
53
+ [
54
+ 2342,
55
+ 1406
56
+ ],
57
+ [
58
+ 2342,
59
+ 1776
60
+ ],
61
+ [
62
+ 2620,
63
+ 1776
64
+ ],
65
+ [
66
+ 2539,
67
+ 1924
68
+ ],
69
+ [
70
+ 2304,
71
+ 1972
72
+ ],
73
+ [
74
+ 2217,
75
+ 1992
76
+ ],
77
+ [
78
+ 2385,
79
+ 2030
80
+ ],
81
+ [
82
+ 2596,
83
+ 2025
84
+ ],
85
+ [
86
+ 2673,
87
+ 1987
88
+ ],
89
+ [
90
+ 2217,
91
+ 1776
92
+ ],
93
+ [
94
+ 2198,
95
+ 1660
96
+ ],
97
+ [
98
+ 2452,
99
+ 1588
100
+ ],
101
+ [
102
+ 2294,
103
+ 1483
104
+ ],
105
+ [
106
+ 2270,
107
+ 1358
108
+ ]
109
+ ]
110
+ },
111
+ "video_path": "limecoke.mp4",
112
+ "instruction": "",
113
+ "primary_points_by_frame": {
114
+ "0": [
115
+ [
116
+ 2126,
117
+ 1099
118
+ ],
119
+ [
120
+ 2366,
121
+ 1099
122
+ ],
123
+ [
124
+ 2683,
125
+ 1080
126
+ ],
127
+ [
128
+ 2784,
129
+ 1176
130
+ ],
131
+ [
132
+ 2640,
133
+ 1176
134
+ ],
135
+ [
136
+ 2539,
137
+ 1176
138
+ ],
139
+ [
140
+ 2318,
141
+ 1176
142
+ ],
143
+ [
144
+ 2116,
145
+ 1291
146
+ ],
147
+ [
148
+ 2496,
149
+ 1291
150
+ ],
151
+ [
152
+ 2654,
153
+ 1286
154
+ ],
155
+ [
156
+ 2654,
157
+ 1406
158
+ ],
159
+ [
160
+ 2342,
161
+ 1406
162
+ ],
163
+ [
164
+ 2342,
165
+ 1776
166
+ ],
167
+ [
168
+ 2620,
169
+ 1776
170
+ ],
171
+ [
172
+ 2539,
173
+ 1924
174
+ ],
175
+ [
176
+ 2304,
177
+ 1972
178
+ ],
179
+ [
180
+ 2217,
181
+ 1992
182
+ ],
183
+ [
184
+ 2385,
185
+ 2030
186
+ ],
187
+ [
188
+ 2596,
189
+ 2025
190
+ ],
191
+ [
192
+ 2673,
193
+ 1987
194
+ ],
195
+ [
196
+ 2217,
197
+ 1776
198
+ ],
199
+ [
200
+ 2198,
201
+ 1660
202
+ ],
203
+ [
204
+ 2452,
205
+ 1588
206
+ ],
207
+ [
208
+ 2294,
209
+ 1483
210
+ ],
211
+ [
212
+ 2270,
213
+ 1358
214
+ ]
215
+ ]
216
+ },
217
+ "primary_frames": [
218
+ 0
219
+ ],
220
+ "first_appears_frame": 0
221
+ }
sample/moving_ball/first_frame.jpg ADDED

Git LFS Details

  • SHA256: 32773994491b764d09cf357983abd4cbd89dd9601fa45a5c2ed1a340ab70df90
  • Pointer size: 131 Bytes
  • Size of remote file: 653 kB
sample/moving_ball/input_video.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e07906cc204ba26c0dd05eed545030cb7e79f2742e983ff0b04d2d9c3c762d29
3
+ size 2014662
sample/moving_ball/prompt.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "bg": "A ball rolls off the table."
3
+ }
sample/moving_ball/quadmask_0.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5904642de05a65f210bd49e3c24b7d0657ef57ff40eb9baafd562962c9dd9189
3
+ size 2485881
sample/pillow/input_video.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6ca3e6b666497e053491772e8f0317e22520c63ebaa8896b8378757d016e0f75
3
+ size 2960087
sample/pillow/prompt.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "bg": "Two pillows placed on the table."
3
+ }
sample/pillow/quadmask_0.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7eb70257593da06f682a3ddda54a9d260d4fc514f645237f5ca74b08f8da61a6
3
+ size 2
sample/pillow/segmentation_info.json ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "total_frames": 62,
3
+ "frame_width": 3840,
4
+ "frame_height": 2160,
5
+ "fps": 12.0,
6
+ "num_points": 8,
7
+ "points_by_frame": {
8
+ "0": [
9
+ [
10
+ 1507,
11
+ 724
12
+ ],
13
+ [
14
+ 1363,
15
+ 638
16
+ ],
17
+ [
18
+ 1190,
19
+ 475
20
+ ],
21
+ [
22
+ 1276,
23
+ 187
24
+ ],
25
+ [
26
+ 1545,
27
+ 168
28
+ ],
29
+ [
30
+ 1660,
31
+ 259
32
+ ],
33
+ [
34
+ 1684,
35
+ 393
36
+ ],
37
+ [
38
+ 1579,
39
+ 825
40
+ ]
41
+ ]
42
+ },
43
+ "video_path": "teaser3/weight_on_pillow.mp4",
44
+ "instruction": "segment the weight",
45
+ "primary_points_by_frame": {
46
+ "0": [
47
+ [
48
+ 1507,
49
+ 724
50
+ ],
51
+ [
52
+ 1363,
53
+ 638
54
+ ],
55
+ [
56
+ 1190,
57
+ 475
58
+ ],
59
+ [
60
+ 1276,
61
+ 187
62
+ ],
63
+ [
64
+ 1545,
65
+ 168
66
+ ],
67
+ [
68
+ 1660,
69
+ 259
70
+ ],
71
+ [
72
+ 1684,
73
+ 393
74
+ ],
75
+ [
76
+ 1579,
77
+ 825
78
+ ]
79
+ ]
80
+ },
81
+ "primary_frames": [
82
+ 0
83
+ ],
84
+ "first_appears_frame": 0
85
+ }
videox_fun/__init__.py ADDED
File without changes
videox_fun/api/api.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import gc
3
+ import hashlib
4
+ import io
5
+ import os
6
+ import tempfile
7
+ from io import BytesIO
8
+
9
+ import gradio as gr
10
+ import requests
11
+ import torch
12
+ from fastapi import FastAPI
13
+ from PIL import Image
14
+
15
+
16
+ # Function to encode a file to Base64
17
+ def encode_file_to_base64(file_path):
18
+ with open(file_path, "rb") as file:
19
+ # Encode the data to Base64
20
+ file_base64 = base64.b64encode(file.read())
21
+ return file_base64
22
+
23
+ def update_edition_api(_: gr.Blocks, app: FastAPI, controller):
24
+ @app.post("/videox_fun/update_edition")
25
+ def _update_edition_api(
26
+ datas: dict,
27
+ ):
28
+ edition = datas.get('edition', 'v2')
29
+
30
+ try:
31
+ controller.update_edition(
32
+ edition
33
+ )
34
+ comment = "Success"
35
+ except Exception as e:
36
+ torch.cuda.empty_cache()
37
+ comment = f"Error. error information is {str(e)}"
38
+
39
+ return {"message": comment}
40
+
41
+ def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
42
+ @app.post("/videox_fun/update_diffusion_transformer")
43
+ def _update_diffusion_transformer_api(
44
+ datas: dict,
45
+ ):
46
+ diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')
47
+
48
+ try:
49
+ controller.update_diffusion_transformer(
50
+ diffusion_transformer_path
51
+ )
52
+ comment = "Success"
53
+ except Exception as e:
54
+ torch.cuda.empty_cache()
55
+ comment = f"Error. error information is {str(e)}"
56
+
57
+ return {"message": comment}
58
+
59
+ def download_from_url(url, timeout=10):
60
+ try:
61
+ response = requests.get(url, timeout=timeout)
62
+ response.raise_for_status() # 检查请求是否成功
63
+ return response.content
64
+ except requests.exceptions.RequestException as e:
65
+ print(f"Error downloading from {url}: {e}")
66
+ return None
67
+
68
+ def save_base64_video(base64_string):
69
+ video_data = base64.b64decode(base64_string)
70
+
71
+ md5_hash = hashlib.md5(video_data).hexdigest()
72
+ filename = f"{md5_hash}.mp4"
73
+
74
+ temp_dir = tempfile.gettempdir()
75
+ file_path = os.path.join(temp_dir, filename)
76
+
77
+ with open(file_path, 'wb') as video_file:
78
+ video_file.write(video_data)
79
+
80
+ return file_path
81
+
82
+ def save_base64_image(base64_string):
83
+ video_data = base64.b64decode(base64_string)
84
+
85
+ md5_hash = hashlib.md5(video_data).hexdigest()
86
+ filename = f"{md5_hash}.jpg"
87
+
88
+ temp_dir = tempfile.gettempdir()
89
+ file_path = os.path.join(temp_dir, filename)
90
+
91
+ with open(file_path, 'wb') as video_file:
92
+ video_file.write(video_data)
93
+
94
+ return file_path
95
+
96
+ def save_url_video(url):
97
+ video_data = download_from_url(url)
98
+ if video_data:
99
+ return save_base64_video(base64.b64encode(video_data))
100
+ return None
101
+
102
+ def save_url_image(url):
103
+ image_data = download_from_url(url)
104
+ if image_data:
105
+ return save_base64_image(base64.b64encode(image_data))
106
+ return None
107
+
108
+ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
109
+ @app.post("/videox_fun/infer_forward")
110
+ def _infer_forward_api(
111
+ datas: dict,
112
+ ):
113
+ base_model_path = datas.get('base_model_path', 'none')
114
+ lora_model_path = datas.get('lora_model_path', 'none')
115
+ lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
116
+ prompt_textbox = datas.get('prompt_textbox', None)
117
+ negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
118
+ sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
119
+ sample_step_slider = datas.get('sample_step_slider', 30)
120
+ resize_method = datas.get('resize_method', "Generate by")
121
+ width_slider = datas.get('width_slider', 672)
122
+ height_slider = datas.get('height_slider', 384)
123
+ base_resolution = datas.get('base_resolution', 512)
124
+ is_image = datas.get('is_image', False)
125
+ generation_method = datas.get('generation_method', False)
126
+ length_slider = datas.get('length_slider', 49)
127
+ overlap_video_length = datas.get('overlap_video_length', 4)
128
+ partial_video_length = datas.get('partial_video_length', 72)
129
+ cfg_scale_slider = datas.get('cfg_scale_slider', 6)
130
+ start_image = datas.get('start_image', None)
131
+ end_image = datas.get('end_image', None)
132
+ validation_video = datas.get('validation_video', None)
133
+ validation_video_mask = datas.get('validation_video_mask', None)
134
+ control_video = datas.get('control_video', None)
135
+ denoise_strength = datas.get('denoise_strength', 0.70)
136
+ seed_textbox = datas.get("seed_textbox", 43)
137
+
138
+ generation_method = "Image Generation" if is_image else generation_method
139
+
140
+ if start_image is not None:
141
+ if start_image.startswith('http'):
142
+ start_image = save_url_image(start_image)
143
+ start_image = [Image.open(start_image)]
144
+ else:
145
+ start_image = base64.b64decode(start_image)
146
+ start_image = [Image.open(BytesIO(start_image))]
147
+
148
+ if end_image is not None:
149
+ if end_image.startswith('http'):
150
+ end_image = save_url_image(end_image)
151
+ end_image = [Image.open(end_image)]
152
+ else:
153
+ end_image = base64.b64decode(end_image)
154
+ end_image = [Image.open(BytesIO(end_image))]
155
+
156
+ if validation_video is not None:
157
+ if validation_video.startswith('http'):
158
+ validation_video = save_url_video(validation_video)
159
+ else:
160
+ validation_video = save_base64_video(validation_video)
161
+
162
+ if validation_video_mask is not None:
163
+ if validation_video_mask.startswith('http'):
164
+ validation_video_mask = save_url_image(validation_video_mask)
165
+ else:
166
+ validation_video_mask = save_base64_image(validation_video_mask)
167
+
168
+ if control_video is not None:
169
+ if control_video.startswith('http'):
170
+ control_video = save_url_video(control_video)
171
+ else:
172
+ control_video = save_base64_video(control_video)
173
+
174
+ try:
175
+ save_sample_path, comment = controller.generate(
176
+ "",
177
+ base_model_path,
178
+ lora_model_path,
179
+ lora_alpha_slider,
180
+ prompt_textbox,
181
+ negative_prompt_textbox,
182
+ sampler_dropdown,
183
+ sample_step_slider,
184
+ resize_method,
185
+ width_slider,
186
+ height_slider,
187
+ base_resolution,
188
+ generation_method,
189
+ length_slider,
190
+ overlap_video_length,
191
+ partial_video_length,
192
+ cfg_scale_slider,
193
+ start_image,
194
+ end_image,
195
+ validation_video,
196
+ validation_video_mask,
197
+ control_video,
198
+ denoise_strength,
199
+ seed_textbox,
200
+ is_api = True,
201
+ )
202
+ except Exception as e:
203
+ gc.collect()
204
+ torch.cuda.empty_cache()
205
+ torch.cuda.ipc_collect()
206
+ save_sample_path = ""
207
+ comment = f"Error. error information is {str(e)}"
208
+ return {"message": comment}
209
+
210
+ if save_sample_path != "":
211
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
212
+ else:
213
+ return {"message": comment, "save_sample_path": save_sample_path}
videox_fun/api/api_multi_nodes.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is modified from https://github.com/xdit-project/xDiT/blob/main/entrypoints/launch.py
2
+ import base64
3
+ import gc
4
+ import os
5
+ from io import BytesIO
6
+
7
+ import gradio as gr
8
+ import torch
9
+ from fastapi import FastAPI, HTTPException
10
+ from PIL import Image
11
+
12
+ from .api import (encode_file_to_base64, save_base64_image, save_base64_video,
13
+ save_url_image, save_url_video)
14
+
15
+ try:
16
+ import ray
17
+ except:
18
+ print("Ray is not installed. If you want to use multi gpus api. Please install it by running 'pip install ray'.")
19
+ ray = None
20
+
21
+ if ray is not None:
22
+ @ray.remote(num_gpus=1)
23
+ class MultiNodesGenerator:
24
+ def __init__(
25
+ self, rank: int, world_size: int, Controller,
26
+ GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
27
+ config_path=None, ulysses_degree=1, ring_degree=1,
28
+ enable_teacache=None, teacache_threshold=None,
29
+ num_skip_start_steps=None, teacache_offload=None, weight_dtype=None,
30
+ savedir_sample=None,
31
+ ):
32
+ # Set PyTorch distributed environment variables
33
+ os.environ["RANK"] = str(rank)
34
+ os.environ["WORLD_SIZE"] = str(world_size)
35
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
36
+ os.environ["MASTER_PORT"] = "29500"
37
+
38
+ self.rank = rank
39
+ self.controller = Controller(
40
+ GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
41
+ ulysses_degree=ulysses_degree, ring_degree=ring_degree, enable_teacache=enable_teacache, teacache_threshold=teacache_threshold, num_skip_start_steps=num_skip_start_steps,
42
+ teacache_offload=teacache_offload, weight_dtype=weight_dtype, savedir_sample=savedir_sample,
43
+ )
44
+
45
+ def generate(self, datas):
46
+ try:
47
+ base_model_path = datas.get('base_model_path', 'none')
48
+ lora_model_path = datas.get('lora_model_path', 'none')
49
+ lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
50
+ prompt_textbox = datas.get('prompt_textbox', None)
51
+ negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
52
+ sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
53
+ sample_step_slider = datas.get('sample_step_slider', 30)
54
+ resize_method = datas.get('resize_method', "Generate by")
55
+ width_slider = datas.get('width_slider', 672)
56
+ height_slider = datas.get('height_slider', 384)
57
+ base_resolution = datas.get('base_resolution', 512)
58
+ is_image = datas.get('is_image', False)
59
+ generation_method = datas.get('generation_method', False)
60
+ length_slider = datas.get('length_slider', 49)
61
+ overlap_video_length = datas.get('overlap_video_length', 4)
62
+ partial_video_length = datas.get('partial_video_length', 72)
63
+ cfg_scale_slider = datas.get('cfg_scale_slider', 6)
64
+ start_image = datas.get('start_image', None)
65
+ end_image = datas.get('end_image', None)
66
+ validation_video = datas.get('validation_video', None)
67
+ validation_video_mask = datas.get('validation_video_mask', None)
68
+ control_video = datas.get('control_video', None)
69
+ denoise_strength = datas.get('denoise_strength', 0.70)
70
+ seed_textbox = datas.get("seed_textbox", 43)
71
+
72
+ generation_method = "Image Generation" if is_image else generation_method
73
+
74
+ if start_image is not None:
75
+ if start_image.startswith('http'):
76
+ start_image = save_url_image(start_image)
77
+ start_image = [Image.open(start_image)]
78
+ else:
79
+ start_image = base64.b64decode(start_image)
80
+ start_image = [Image.open(BytesIO(start_image))]
81
+
82
+ if end_image is not None:
83
+ if end_image.startswith('http'):
84
+ end_image = save_url_image(end_image)
85
+ end_image = [Image.open(end_image)]
86
+ else:
87
+ end_image = base64.b64decode(end_image)
88
+ end_image = [Image.open(BytesIO(end_image))]
89
+
90
+ if validation_video is not None:
91
+ if validation_video.startswith('http'):
92
+ validation_video = save_url_video(validation_video)
93
+ else:
94
+ validation_video = save_base64_video(validation_video)
95
+
96
+ if validation_video_mask is not None:
97
+ if validation_video_mask.startswith('http'):
98
+ validation_video_mask = save_url_image(validation_video_mask)
99
+ else:
100
+ validation_video_mask = save_base64_image(validation_video_mask)
101
+
102
+ if control_video is not None:
103
+ if control_video.startswith('http'):
104
+ control_video = save_url_video(control_video)
105
+ else:
106
+ control_video = save_base64_video(control_video)
107
+
108
+ try:
109
+ save_sample_path, comment = self.controller.generate(
110
+ "",
111
+ base_model_path,
112
+ lora_model_path,
113
+ lora_alpha_slider,
114
+ prompt_textbox,
115
+ negative_prompt_textbox,
116
+ sampler_dropdown,
117
+ sample_step_slider,
118
+ resize_method,
119
+ width_slider,
120
+ height_slider,
121
+ base_resolution,
122
+ generation_method,
123
+ length_slider,
124
+ overlap_video_length,
125
+ partial_video_length,
126
+ cfg_scale_slider,
127
+ start_image,
128
+ end_image,
129
+ validation_video,
130
+ validation_video_mask,
131
+ control_video,
132
+ denoise_strength,
133
+ seed_textbox,
134
+ is_api = True,
135
+ )
136
+ except Exception as e:
137
+ gc.collect()
138
+ torch.cuda.empty_cache()
139
+ torch.cuda.ipc_collect()
140
+ save_sample_path = ""
141
+ comment = f"Error. error information is {str(e)}"
142
+ return {"message": comment}
143
+
144
+ import torch.distributed as dist
145
+ if dist.get_rank() == 0:
146
+ if save_sample_path != "":
147
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
148
+ else:
149
+ return {"message": comment, "save_sample_path": save_sample_path}
150
+ return None
151
+
152
+ except Exception as e:
153
+ self.logger.error(f"Error generating image: {str(e)}")
154
+ raise HTTPException(status_code=500, detail=str(e))
155
+
156
+ class MultiNodesEngine:
157
+ def __init__(
158
+ self,
159
+ world_size,
160
+ Controller,
161
+ GPU_memory_mode,
162
+ scheduler_dict,
163
+ model_name,
164
+ model_type,
165
+ config_path,
166
+ ulysses_degree,
167
+ ring_degree,
168
+ enable_teacache,
169
+ teacache_threshold,
170
+ num_skip_start_steps,
171
+ teacache_offload,
172
+ weight_dtype,
173
+ savedir_sample
174
+ ):
175
+ # Ensure Ray is initialized
176
+ if not ray.is_initialized():
177
+ ray.init()
178
+
179
+ num_workers = world_size
180
+ self.workers = [
181
+ MultiNodesGenerator.remote(
182
+ rank, world_size, Controller,
183
+ GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
184
+ ulysses_degree=ulysses_degree, ring_degree=ring_degree, enable_teacache=enable_teacache, teacache_threshold=teacache_threshold, num_skip_start_steps=num_skip_start_steps,
185
+ teacache_offload=teacache_offload, weight_dtype=weight_dtype, savedir_sample=savedir_sample,
186
+ )
187
+ for rank in range(num_workers)
188
+ ]
189
+ print("Update workers done")
190
+
191
+ async def generate(self, data):
192
+ results = ray.get([
193
+ worker.generate.remote(data)
194
+ for worker in self.workers
195
+ ])
196
+
197
+ return next(path for path in results if path is not None)
198
+
199
+ def multi_nodes_infer_forward_api(_: gr.Blocks, app: FastAPI, engine):
200
+
201
+ @app.post("/videox_fun/infer_forward")
202
+ async def _multi_nodes_infer_forward_api(
203
+ datas: dict,
204
+ ):
205
+ try:
206
+ result = await engine.generate(datas)
207
+ return result
208
+ except Exception as e:
209
+ if isinstance(e, HTTPException):
210
+ raise e
211
+ raise HTTPException(status_code=500, detail=str(e))
212
else:
    # Ray is unavailable: expose None placeholders so importers can
    # feature-detect multi-node support instead of hitting ImportError.
    MultiNodesEngine = None
    MultiNodesGenerator = None
    multi_nodes_infer_forward_api = None
videox_fun/data/bucket_sampler.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ import glob
4
+ from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
5
+ Sized, TypeVar, Union)
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+ from torch.utils.data import BatchSampler, Dataset, Sampler
12
+
13
+ ASPECT_RATIO_512 = {
14
+ '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
15
+ '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
16
+ '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
17
+ '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
18
+ '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
19
+ '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
20
+ '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
21
+ '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
22
+ '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
23
+ '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
24
+ }
25
+ ASPECT_RATIO_RANDOM_CROP_512 = {
26
+ '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
27
+ '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
28
+ '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
29
+ '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
30
+ '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
31
+ }
32
+ ASPECT_RATIO_RANDOM_CROP_PROB = [
33
+ 1, 2,
34
+ 4, 4, 4, 4,
35
+ 8, 8, 8,
36
+ 4, 4, 4, 4,
37
+ 2, 1
38
+ ]
39
+ ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
40
+
41
+ def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
42
+ aspect_ratio = height / width
43
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
44
+ return ratios[closest_ratio], float(closest_ratio)
45
+
46
+ def get_image_size_without_loading(path):
47
+ with Image.open(path) as img:
48
+ return img.size # (width, height)
49
+
50
+ class RandomSampler(Sampler[int]):
51
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
52
+
53
+ If with replacement, then user can specify :attr:`num_samples` to draw.
54
+
55
+ Args:
56
+ data_source (Dataset): dataset to sample from
57
+ replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
58
+ num_samples (int): number of samples to draw, default=`len(dataset)`.
59
+ generator (Generator): Generator used in sampling.
60
+ """
61
+
62
+ data_source: Sized
63
+ replacement: bool
64
+
65
+ def __init__(self, data_source: Sized, replacement: bool = False,
66
+ num_samples: Optional[int] = None, generator=None) -> None:
67
+ self.data_source = data_source
68
+ self.replacement = replacement
69
+ self._num_samples = num_samples
70
+ self.generator = generator
71
+ self._pos_start = 0
72
+
73
+ if not isinstance(self.replacement, bool):
74
+ raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
75
+
76
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
77
+ raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
78
+
79
+ @property
80
+ def num_samples(self) -> int:
81
+ # dataset size might change at runtime
82
+ if self._num_samples is None:
83
+ return len(self.data_source)
84
+ return self._num_samples
85
+
86
+ def __iter__(self) -> Iterator[int]:
87
+ n = len(self.data_source)
88
+ if self.generator is None:
89
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
90
+ generator = torch.Generator()
91
+ generator.manual_seed(seed)
92
+ else:
93
+ generator = self.generator
94
+
95
+ if self.replacement:
96
+ for _ in range(self.num_samples // 32):
97
+ yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
98
+ yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
99
+ else:
100
+ for _ in range(self.num_samples // n):
101
+ xx = torch.randperm(n, generator=generator).tolist()
102
+ if self._pos_start >= n:
103
+ self._pos_start = 0
104
+ print("xx top 10", xx[:10], self._pos_start)
105
+ for idx in range(self._pos_start, n):
106
+ yield xx[idx]
107
+ self._pos_start = (self._pos_start + 1) % n
108
+ self._pos_start = 0
109
+ yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
110
+
111
+ def __len__(self) -> int:
112
+ return self.num_samples
113
+
114
+ class AspectRatioBatchImageSampler(BatchSampler):
115
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
116
+
117
+ Args:
118
+ sampler (Sampler): Base sampler.
119
+ dataset (Dataset): Dataset providing data information.
120
+ batch_size (int): Size of mini-batch.
121
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
122
+ its size would be less than ``batch_size``.
123
+ aspect_ratios (dict): The predefined aspect ratios.
124
+ """
125
+ def __init__(
126
+ self,
127
+ sampler: Sampler,
128
+ dataset: Dataset,
129
+ batch_size: int,
130
+ train_folder: str = None,
131
+ aspect_ratios: dict = ASPECT_RATIO_512,
132
+ drop_last: bool = False,
133
+ config=None,
134
+ **kwargs
135
+ ) -> None:
136
+ if not isinstance(sampler, Sampler):
137
+ raise TypeError('sampler should be an instance of ``Sampler``, '
138
+ f'but got {sampler}')
139
+ if not isinstance(batch_size, int) or batch_size <= 0:
140
+ raise ValueError('batch_size should be a positive integer value, '
141
+ f'but got batch_size={batch_size}')
142
+ self.sampler = sampler
143
+ self.dataset = dataset
144
+ self.train_folder = train_folder
145
+ self.batch_size = batch_size
146
+ self.aspect_ratios = aspect_ratios
147
+ self.drop_last = drop_last
148
+ self.config = config
149
+ # buckets for each aspect ratio
150
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
151
+ # [str(k) for k, v in aspect_ratios]
152
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
153
+
154
+ def __iter__(self):
155
+ for idx in self.sampler:
156
+ try:
157
+ image_dict = self.dataset[idx]
158
+
159
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
160
+ if width is None or height is None:
161
+ image_id, name = image_dict['file_path'], image_dict['text']
162
+ if self.train_folder is None:
163
+ image_dir = image_id
164
+ else:
165
+ image_dir = os.path.join(self.train_folder, image_id)
166
+
167
+ width, height = get_image_size_without_loading(image_dir)
168
+
169
+ ratio = height / width # self.dataset[idx]
170
+ else:
171
+ height = int(height)
172
+ width = int(width)
173
+ ratio = height / width # self.dataset[idx]
174
+ except Exception as e:
175
+ print(e)
176
+ continue
177
+ # find the closest aspect ratio
178
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
179
+ if closest_ratio not in self.current_available_bucket_keys:
180
+ continue
181
+ bucket = self._aspect_ratio_buckets[closest_ratio]
182
+ bucket.append(idx)
183
+ # yield a batch of indices in the same aspect ratio group
184
+ if len(bucket) == self.batch_size:
185
+ yield bucket[:]
186
+ del bucket[:]
187
+
188
+ class AspectRatioBatchSampler(BatchSampler):
189
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
190
+
191
+ Args:
192
+ sampler (Sampler): Base sampler.
193
+ dataset (Dataset): Dataset providing data information.
194
+ batch_size (int): Size of mini-batch.
195
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
196
+ its size would be less than ``batch_size``.
197
+ aspect_ratios (dict): The predefined aspect ratios.
198
+ """
199
+ def __init__(
200
+ self,
201
+ sampler: Sampler,
202
+ dataset: Dataset,
203
+ batch_size: int,
204
+ video_folder: str = None,
205
+ train_data_format: str = "webvid",
206
+ aspect_ratios: dict = ASPECT_RATIO_512,
207
+ drop_last: bool = False,
208
+ config=None,
209
+ **kwargs
210
+ ) -> None:
211
+ if not isinstance(sampler, Sampler):
212
+ raise TypeError('sampler should be an instance of ``Sampler``, '
213
+ f'but got {sampler}')
214
+ if not isinstance(batch_size, int) or batch_size <= 0:
215
+ raise ValueError('batch_size should be a positive integer value, '
216
+ f'but got batch_size={batch_size}')
217
+ self.sampler = sampler
218
+ self.dataset = dataset
219
+ self.video_folder = video_folder
220
+ self.train_data_format = train_data_format
221
+ self.batch_size = batch_size
222
+ self.aspect_ratios = aspect_ratios
223
+ self.drop_last = drop_last
224
+ self.config = config
225
+ # buckets for each aspect ratio
226
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
227
+ # [str(k) for k, v in aspect_ratios]
228
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
229
+
230
+ def __iter__(self):
231
+ for idx in self.sampler:
232
+ try:
233
+ video_dict = self.dataset[idx]
234
+ width, more = video_dict.get("width", None), video_dict.get("height", None)
235
+
236
+ if width is None or height is None:
237
+ if self.train_data_format == "normal":
238
+ video_id, name = video_dict['file_path'], video_dict['text']
239
+ if self.video_folder is None:
240
+ video_dir = video_id
241
+ else:
242
+ video_dir = os.path.join(self.video_folder, video_id)
243
+ else:
244
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
245
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
246
+ cap = cv2.VideoCapture(video_dir)
247
+
248
+ # 获取视频尺寸
249
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数
250
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数
251
+
252
+ ratio = height / width # self.dataset[idx]
253
+ else:
254
+ height = int(height)
255
+ width = int(width)
256
+ ratio = height / width # self.dataset[idx]
257
+ except Exception as e:
258
+ print(e, self.dataset[idx], "This item is error, please check it.")
259
+ continue
260
+ # find the closest aspect ratio
261
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
262
+ if closest_ratio not in self.current_available_bucket_keys:
263
+ continue
264
+ bucket = self._aspect_ratio_buckets[closest_ratio]
265
+ bucket.append(idx)
266
+ # yield a batch of indices in the same aspect ratio group
267
+ if len(bucket) == self.batch_size:
268
+ yield bucket[:]
269
+ del bucket[:]
270
+
271
+ class AspectRatioBatchImageVideoSampler(BatchSampler):
272
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
273
+
274
+ Args:
275
+ sampler (Sampler): Base sampler.
276
+ dataset (Dataset): Dataset providing data information.
277
+ batch_size (int): Size of mini-batch.
278
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
279
+ its size would be less than ``batch_size``.
280
+ aspect_ratios (dict): The predefined aspect ratios.
281
+ """
282
+
283
+ def __init__(self,
284
+ sampler: Sampler,
285
+ dataset: Dataset,
286
+ batch_size: int,
287
+ train_folder: str = None,
288
+ aspect_ratios: dict = ASPECT_RATIO_512,
289
+ drop_last: bool = False
290
+ ) -> None:
291
+ if not isinstance(sampler, Sampler):
292
+ raise TypeError('sampler should be an instance of ``Sampler``, '
293
+ f'but got {sampler}')
294
+ if not isinstance(batch_size, int) or batch_size <= 0:
295
+ raise ValueError('batch_size should be a positive integer value, '
296
+ f'but got batch_size={batch_size}')
297
+ self.sampler = sampler
298
+ self.dataset = dataset
299
+ self.train_folder = train_folder
300
+ self.batch_size = batch_size
301
+ self.aspect_ratios = aspect_ratios
302
+ self.drop_last = drop_last
303
+
304
+ # buckets for each aspect ratio
305
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
306
+ self.bucket = {
307
+ 'image':{ratio: [] for ratio in aspect_ratios},
308
+ 'video':{ratio: [] for ratio in aspect_ratios}
309
+ }
310
+
311
+ def __iter__(self):
312
+ for idx in self.sampler:
313
+ content_type = self.dataset[idx].get('type', 'image')
314
+ if content_type == 'image':
315
+ try:
316
+ image_dict = self.dataset[idx]
317
+
318
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
319
+ if width is None or height is None:
320
+ image_id, name = image_dict['file_path'], image_dict['text']
321
+ if self.train_folder is None:
322
+ image_dir = image_id
323
+ else:
324
+ image_dir = os.path.join(self.train_folder, image_id)
325
+
326
+ width, height = get_image_size_without_loading(image_dir)
327
+
328
+ ratio = height / width # self.dataset[idx]
329
+ else:
330
+ height = int(height)
331
+ width = int(width)
332
+ ratio = height / width # self.dataset[idx]
333
+ except Exception as e:
334
+ print(e, self.dataset[idx], "This item is error, please check it.")
335
+ continue
336
+ # find the closest aspect ratio
337
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
338
+ if closest_ratio not in self.current_available_bucket_keys:
339
+ continue
340
+ bucket = self.bucket['image'][closest_ratio]
341
+ bucket.append(idx)
342
+ # yield a batch of indices in the same aspect ratio group
343
+ if len(bucket) == self.batch_size:
344
+ yield bucket[:]
345
+ del bucket[:]
346
+ else:
347
+ try:
348
+ video_dict = self.dataset[idx]
349
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
350
+
351
+ if width is None or height is None:
352
+ if video_dict['type'] == 'video_mask_tuple':
353
+ video_dir = video_dict['file_path']
354
+ if os.path.isdir(os.path.join(video_dir, 'input')):
355
+ sample_path = list(glob.glob(os.path.join(video_dir, 'input', '*.png')))[0]
356
+ width, height = get_image_size_without_loading(sample_path)
357
+ else:
358
+ sample_path = os.path.join(video_dir, 'rgb_full.mp4')
359
+ cap = cv2.VideoCapture(sample_path)
360
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
361
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
362
+ else:
363
+ video_id, name = video_dict['file_path'], video_dict['text']
364
+ if self.train_folder is None:
365
+ video_dir = video_id
366
+ else:
367
+ video_dir = os.path.join(self.train_folder, video_id)
368
+ cap = cv2.VideoCapture(video_dir)
369
+
370
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
371
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
372
+
373
+ ratio = height / width # self.dataset[idx]
374
+ else:
375
+ height = int(height)
376
+ width = int(width)
377
+ ratio = height / width # self.dataset[idx]
378
+ except Exception as e:
379
+ print(e, self.dataset[idx], "This item is error, please check it.")
380
+ continue
381
+ # find the closest aspect ratio
382
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
383
+ if closest_ratio not in self.current_available_bucket_keys:
384
+ continue
385
+ bucket = self.bucket['video'][closest_ratio]
386
+ bucket.append(idx)
387
+ # yield a batch of indices in the same aspect ratio group
388
+ if len(bucket) == self.batch_size:
389
+ yield bucket[:]
390
+ del bucket[:]
videox_fun/data/dataset_image.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+
12
class CC15M(Dataset):
    """Image-caption dataset backed by a JSON list of {file_path, text} records.

    When ``enable_bucket`` is False, images are resized/cropped/normalized to
    ``resolution``; when True, raw HWC uint8 arrays are returned so a bucket
    sampler can handle sizing downstream.
    """
    def __init__(
        self,
        json_path,
        video_folder=None,
        resolution=512,
        enable_bucket=False,
    ):
        print(f"loading annotations from {json_path} ...")
        self.dataset = json.load(open(json_path, 'r'))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.enable_bucket = enable_bucket
        self.video_folder = video_folder

        # Accept either an int (square) or an (h, w) pair for the target size.
        resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(resolution[0]),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        # Resolve the image path (optionally under video_folder) and load it.
        video_dict = self.dataset[idx]
        video_id, name = video_dict['file_path'], video_dict['text']

        if self.video_folder is None:
            video_dir = video_id
        else:
            video_dir = os.path.join(self.video_folder, video_id)

        pixel_values = Image.open(video_dir).convert("RGB")
        return pixel_values, name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Retry with a random index until a sample loads successfully, so one
        # corrupt file does not crash the whole epoch.
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception as e:
                print(e)
                idx = random.randint(0, self.length-1)

        if not self.enable_bucket:
            pixel_values = self.pixel_transforms(pixel_values)
        else:
            # Bucketing mode: return raw pixels; resizing happens downstream.
            pixel_values = np.array(pixel_values)

        sample = dict(pixel_values=pixel_values, text=name)
        return sample
67
+
68
if __name__ == "__main__":
    # Smoke test: build the dataset and iterate a few batches.
    dataset = CC15M(
        # BUGFIX: the constructor parameter is `json_path`, not `csv_path`;
        # the old keyword raised TypeError at startup.
        json_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
        resolution=512,
    )

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))
videox_fun/data/dataset_image_video.py ADDED
@@ -0,0 +1,1067 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import io
3
+ import json
4
+ import math
5
+ import os
6
+ import glob
7
+ import random
8
+ from threading import Thread
9
+ import mediapy as media
10
+ import time
11
+
12
+ import albumentations
13
+ import cv2
14
+ import gc
15
+ import numpy as np
16
+ import torch
17
+ import torchvision.transforms as transforms
18
+ from scipy.special import binom
19
+
20
+ from func_timeout import func_timeout, FunctionTimedOut
21
+ from decord import VideoReader
22
+ from PIL import Image
23
+ from torch.utils.data import BatchSampler, Sampler
24
+ from torch.utils.data.dataset import Dataset
25
+ from contextlib import contextmanager
26
+
27
+ VIDEO_READER_TIMEOUT = 20
28
+
29
bernstein = lambda n, k, t: binom(n,k)* t**k * (1.-t)**(n-k)


# codes from https://stackoverflow.com/questions/50731785/create-random-shape-contour-using-matplotlib
def bezier(points, num=200):
    """Evaluate the Bezier curve with control ``points`` at ``num`` samples.

    Returns an (num, 2) array of curve coordinates for t in [0, 1].
    """
    degree = len(points) - 1
    ts = np.linspace(0, 1, num=num)
    curve = np.zeros((num, 2))
    # Sum each control point weighted by its Bernstein basis polynomial.
    for k, ctrl in enumerate(points):
        curve += np.outer(bernstein(degree, k, ts), ctrl)
    return curve
39
+
40
class Segment():
    """One cubic Bezier segment from p1 to p2 whose control points leave the
    endpoints at the given tangent angles."""
    def __init__(self, p1, p2, angle1, angle2, **kw):
        self.p1 = p1
        self.p2 = p2
        self.angle1 = angle1
        self.angle2 = angle2
        self.numpoints = kw.get("numpoints", 100)
        r = kw.get("r", 0.3)
        d = np.sqrt(np.sum((self.p2-self.p1)**2))
        # Control-point distance scales with the chord length.
        self.r = r*d
        self.p = np.zeros((4,2))
        self.p[0,:] = self.p1[:]
        self.p[3,:] = self.p2[:]
        self.calc_intermediate_points(self.r)

    def calc_intermediate_points(self,r):
        # Place the two inner control points along the tangent directions and
        # sample the resulting cubic curve into self.curve.
        # NOTE(review): the `r` parameter is ignored — self.r is used instead;
        # confirm this is intentional.
        self.p[1,:] = self.p1 + np.array(
            [self.r*np.cos(self.angle1), self.r*np.sin(self.angle1)])
        self.p[2,:] = self.p2 + np.array(
            [self.r*np.cos(self.angle2+np.pi), self.r*np.sin(self.angle2+np.pi)])
        self.curve = bezier(self.p,self.numpoints)
61
+
62
+
63
def get_curve(points, **kw):
    """Chain Bezier ``Segment``s through consecutive rows of ``points``.

    Each row is (x, y, angle); extra kwargs are forwarded to ``Segment``.
    Returns (segments, concatenated_curve).
    """
    segments = [
        Segment(points[i, :2], points[i + 1, :2], points[i, 2], points[i + 1, 2], **kw)
        for i in range(len(points) - 1)
    ]
    curve = np.concatenate([seg.curve for seg in segments])
    return segments, curve
70
+
71
+
72
def ccw_sort(p):
    """Order 2-D points by angle around their centroid.

    Note the swapped arguments to arctan2 (x before y), matching the
    original implementation's angular convention.
    """
    centered = p - p.mean(axis=0)
    angles = np.arctan2(centered[:, 0], centered[:, 1])
    return p[np.argsort(angles), :]
76
+
77
+
78
def get_bezier_curve(a, rad=0.2, edgy=0):
    """ given an array of points *a*, create a curve through
    those points.
    *rad* is a number between 0 and 1 to steer the distance of
    control points.
    *edgy* is a parameter which controls how "edgy" the curve is,
    edgy=0 is smoothest."""
    # Map edgy in [0, inf) to a blend weight p in [0.5, 1).
    p = np.arctan(edgy)/np.pi+.5
    a = ccw_sort(a)
    # Close the loop by appending the first point again.
    a = np.append(a, np.atleast_2d(a[0,:]), axis=0)
    d = np.diff(a, axis=0)
    ang = np.arctan2(d[:,1],d[:,0])
    # Normalize edge angles into [0, 2*pi).
    f = lambda ang : (ang>=0)*ang + (ang<0)*(ang+2*np.pi)
    ang = f(ang)
    ang1 = ang
    ang2 = np.roll(ang,1)
    # Blend each vertex's outgoing edge angle with the incoming one; the last
    # term flips by pi when the two differ by more than pi (wrap-around).
    ang = p*ang1 + (1-p)*ang2 + (np.abs(ang2-ang1) > np.pi )*np.pi
    ang = np.append(ang, [ang[0]])
    # Attach the per-vertex tangent angle as a third column for get_curve.
    a = np.append(a, np.atleast_2d(ang).T, axis=1)
    s, c = get_curve(a, r=rad, method="var")
    x,y = c.T
    return x,y, a
100
+
101
+
102
def get_random_points(n=5, scale=0.8, mindst=None, rec=0):
    """ create n random points in the unit square, which are *mindst*
    apart, then scale them.

    Retries (up to 200 recursions) until all consecutive points of the
    angularly sorted set are at least `mindst` apart; after 200 attempts
    the last draw is accepted as-is.
    """
    mindst = mindst or .7/n
    a = np.random.rand(n, 2)
    # BUGFIX: the original computed sqrt((dx + dy)**2) = |dx + dy| because the
    # square was applied *after* the axis sum; this is the proper Euclidean
    # distance between consecutive (ccw-sorted) points.
    d = np.sqrt(np.sum(np.diff(ccw_sort(a), axis=0)**2, axis=1))
    if np.all(d >= mindst) or rec >= 200:
        return a * scale
    else:
        return get_random_points(n=n, scale=scale, mindst=mindst, rec=rec + 1)
112
+
113
+
114
def fill_mask(shape, x, y, fill_val=255):
    """Rasterize the closed polygon with vertices (x, y) into a (h, w) uint8 mask.

    `shape` is the (f, c, h, w) clip shape; only h and w are used. `x` and `y`
    are integer pixel coordinates; pixels inside the polygon get `fill_val`.
    """
    h, w = shape[2], shape[3]
    canvas = np.zeros((h, w), dtype=np.uint8)
    polygon = np.array([x, y], np.int32).T
    return cv2.fillPoly(canvas, [polygon], fill_val)
119
+
120
+
121
def random_shift(x, y, scale_range = [0.2, 0.7], trans_perturb_range=[-0.2, 0.2]):
    """Randomly scale and translate normalized contour coordinates.

    Scales x and y independently by a factor in `scale_range`, translates so
    the scaled shape stays (mostly) inside the unit square, then adds a small
    uniform jitter from `trans_perturb_range` to each axis.
    """
    # NOTE: RNG draw order matters for reproducibility — keep it fixed.
    lo, hi = scale_range
    sx = np.random.uniform(lo, hi)
    sy = np.random.uniform(lo, hi)
    tx = np.random.uniform(0., 1. - sx)
    ty = np.random.uniform(0., 1. - sy)
    jitter_lo, jitter_hi = trans_perturb_range
    shifted_x = x * sx + tx + np.random.uniform(jitter_lo, jitter_hi)
    shifted_y = y * sy + ty + np.random.uniform(jitter_lo, jitter_hi)
    return shifted_x, shifted_y
129
+
130
+
131
def get_random_shape_mask(
    shape, n_pts_range=[3, 10], rad_range=[0.0, 1.0], edgy_range=[0.0, 0.1], n_keyframes_range=[2, 25],
    random_drop_range=[0.0, 0.2],
):
    """Generate an animated random-blob mask video.

    Samples a random closed Bezier blob at several keyframes, linearly
    interpolates the contour between keyframes, rasterizes every frame, and
    finally zeroes out a random contiguous span of frames.

    Args:
        shape: (f, c, h, w) of the target clip.
        n_pts_range / rad_range / edgy_range: ranges for the blob generator.
        n_keyframes_range: range for the number of contour keyframes.
        random_drop_range: fraction range of frames to blank out.

    Returns:
        (f, h, w) float array with values in [0, 1].
    """
    f, _, h, w = shape

    n_pts = np.random.randint(n_pts_range[0], n_pts_range[1])
    n_keyframes = np.random.randint(n_keyframes_range[0], n_keyframes_range[1])
    # BUGFIX: when n_keyframes - 1 > f the original interval was 0 and
    # range(0, f, 0) raised ValueError; clamp the stride to at least 1.
    keyframe_interval = max(f // (n_keyframes - 1), 1)
    keyframe_indices = list(range(0, f, keyframe_interval))
    # Make sure the last keyframe lands exactly on the final frame.
    if len(keyframe_indices) == n_keyframes:
        keyframe_indices[-1] = f - 1
    else:
        keyframe_indices.append(f - 1)

    x_all_frames, y_all_frames = [], []
    for i, keyframe_index in enumerate(keyframe_indices):
        rad = np.random.uniform(rad_range[0], rad_range[1])
        edgy = np.random.uniform(edgy_range[0], edgy_range[1])
        x_kf, y_kf, _ = get_bezier_curve(get_random_points(n=n_pts), rad=rad, edgy=edgy)
        x_kf, y_kf = random_shift(x_kf, y_kf)
        if i == 0:
            x_all_frames.append(x_kf[None])
            y_all_frames.append(y_kf[None])
        else:
            # Linearly interpolate the contour from the previous keyframe;
            # drop the first sample to avoid duplicating the keyframe itself.
            x_interval = np.linspace(x_all_frames[-1][-1], x_kf, keyframe_index - keyframe_indices[i - 1] + 1)
            y_interval = np.linspace(y_all_frames[-1][-1], y_kf, keyframe_index - keyframe_indices[i - 1] + 1)
            x_all_frames.append(x_interval[1:])
            y_all_frames.append(y_interval[1:])
    x_all_frames = np.concatenate(x_all_frames, axis=0)
    y_all_frames = np.concatenate(y_all_frames, axis=0)

    # Rasterize every per-frame contour into a binary mask.
    masks = []
    for x, y in zip(x_all_frames, y_all_frames):
        x = np.round(x * w).astype(np.int32)
        y = np.round(y * h).astype(np.int32)
        mask = fill_mask(shape, x, y)
        masks.append(mask)
    masks = np.stack(masks, axis=0).astype(float) / 255.

    # Blank a random contiguous span of frames so the mask can "disappear".
    n_frames_random_drop = int(np.random.uniform(random_drop_range[0], random_drop_range[1]) * f)
    drop_index = np.random.randint(0, f - n_frames_random_drop)
    masks[drop_index:drop_index + n_frames_random_drop] = 0

    return masks  # (f, h, w), <float>[0, 1]
175
+
176
+
177
def get_random_mask(shape, mask_type_probs=[0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.8]):
    """Sample one random inpainting mask video.

    Args:
        shape: (f, c, h, w) of the clip the mask is generated for.
        mask_type_probs: probabilities for mask types 0-10 (video case only).

    Returns:
        Float tensor of shape (f, 1, h, w) with values in [0, 1];
        nonzero marks masked (to-be-inpainted) regions.
    """
    f, c, h, w = shape

    # Single images (f == 1) only get a static rectangle (0) or full mask (1).
    if f != 1:
        mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], p=mask_type_probs)
    else:
        mask_index = np.random.choice([0, 1], p = [0.2, 0.8])
    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)

    if mask_index == 0:
        # Static rectangle shared by all frames, clipped to the image bounds.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()
        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()

        start_x = max(center_x - block_size_x // 2, 0)
        end_x = min(center_x + block_size_x // 2, w)
        start_y = max(center_y - block_size_y // 2, 0)
        end_y = min(center_y + block_size_y // 2, h)
        mask[:, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 1:
        # Everything masked (full generation).
        mask[:, :, :, :] = 1
    elif mask_index == 2:
        # Keep the first few frames, mask the rest (video continuation).
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:, :, :, :] = 1
    elif mask_index == 3:
        # Keep the first and last few frames, mask the middle (interpolation).
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
    elif mask_index == 4:
        # Rectangle masked only over a random temporal window.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()
        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()

        start_x = max(center_x - block_size_x // 2, 0)
        end_x = min(center_x + block_size_x // 2, w)
        start_y = max(center_y - block_size_y // 2, 0)
        end_y = min(center_y + block_size_y // 2, h)

        mask_frame_before = np.random.randint(0, f // 2)
        mask_frame_after = np.random.randint(f // 2, f)
        mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 5:
        # Per-pixel random noise mask.
        mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
    elif mask_index == 6:
        # Small random rectangles on a random subset of frames.
        num_frames_to_mask = random.randint(1, max(f // 2, 1))
        frames_to_mask = random.sample(range(f), num_frames_to_mask)

        for i in frames_to_mask:
            block_height = random.randint(1, h // 4)
            block_width = random.randint(1, w // 4)
            top_left_y = random.randint(0, h - block_height)
            top_left_x = random.randint(0, w - block_width)
            mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
    elif mask_index == 7:
        # Static ellipse with semi-axes a (horizontal) and b (vertical).
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item()
        b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()

        for i in range(h):
            for j in range(w):
                if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
                    mask[:, :, i, j] = 1
    elif mask_index == 8:
        # Static circle.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
        for i in range(h):
            for j in range(w):
                if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
                    mask[:, :, i, j] = 1
    elif mask_index == 9:
        # Each frame independently fully masked or fully kept.
        for idx in range(f):
            if np.random.rand() > 0.5:
                mask[idx, :, :, :] = 1
    else:
        # mask_index == 10: union of 1-3 animated random blob masks;
        # get_random_shape_mask already returns floats in [0, 1].
        num_objs = np.random.randint(1, 4)
        mask_npy = get_random_shape_mask(shape)
        for i in range(num_objs - 1):
            mask_npy += get_random_shape_mask(shape).clip(0, 1)

        mask = torch.from_numpy(mask_npy).unsqueeze(1)

    return mask.float()
262
+
263
+
264
def get_random_mask_multi(shape, mask_type_probs, range_num_masks=[1, 7]):
    """Union of several independently sampled random masks.

    Draws a random count from `range_num_masks`, samples that many masks via
    `get_random_mask`, and combines them with a clipped sum (logical OR on
    binary masks). Returns an (f, 1, h, w) float tensor in [0, 1].
    """
    num_masks = np.random.randint(range_num_masks[0], range_num_masks[1])
    combined = None
    for _ in range(num_masks):
        single = get_random_mask(shape, mask_type_probs)
        combined = single if combined is None else (combined + single).clip(0, 1)
    return combined
274
+
275
+
276
class ImageVideoSampler(BatchSampler):
    """A sampler wrapper for grouping samples of the same content type into a batch.

    Indices from the base sampler are bucketed by their dataset entry's
    'type' field ('image', 'video', or 'video_mask_tuple'); a batch is
    yielded whenever any bucket fills up, so each batch is homogeneous.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information (must expose
            ``dataset`` as a list of dicts with an optional 'type' key).
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # One bucket per content type.
        self.bucket = {'image': [], 'video': [], 'video_mask_tuple': []}

    def __iter__(self):
        for idx in self.sampler:
            content_type = self.dataset.dataset[idx].get('type', 'image')
            bucket = self.bucket[content_type]
            bucket.append(idx)

            # Only the bucket that just grew can have reached batch_size.
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

        # BUGFIX: honor drop_last. The original silently discarded every
        # partially filled bucket at the end of an epoch regardless of the
        # drop_last flag; now leftovers are yielded when drop_last is False.
        if not self.drop_last:
            for bucket in self.bucket.values():
                if bucket:
                    yield bucket[:]
                    del bucket[:]
326
+
327
+
328
@contextmanager
def VideoReader_contextmanager(*args, **kwargs):
    """Context-managed VideoReader that is explicitly freed on exit.

    Deletes the reader and forces a garbage-collection pass afterwards so
    the underlying video handle is released promptly even inside long loops.
    """
    reader = VideoReader(*args, **kwargs)
    try:
        yield reader
    finally:
        del reader
        gc.collect()
336
+
337
+
338
def get_video_reader_batch(video_reader, batch_index):
    """Fetch the frames at `batch_index` from the reader as a numpy array."""
    return video_reader.get_batch(batch_index).asnumpy()
341
+
342
+
343
def _read_video_from_dir(video_dir):
    """Load all PNG frames in `video_dir` (sorted by filename) as one array.

    Returns a stacked (f, h, w, c) numpy array. Raises ValueError when the
    directory has no PNG files or none of them could be read.
    """
    frame_paths = sorted(glob.glob(os.path.join(video_dir, '*.png')))
    if not frame_paths:
        raise ValueError(f"No PNG files found in directory: {video_dir}")

    frames = [media.read_image(path) for path in frame_paths]
    if not frames:
        raise ValueError(f"Failed to read any frames from directory: {video_dir}")

    return np.stack(frames, axis=0)
358
+
359
+
360
def resize_frame(frame, target_short_side):
    """Downscale `frame` so its shorter side equals `target_short_side`.

    The aspect ratio is preserved. If the frame's short side is already
    smaller than the target, the frame is returned unchanged (no upscaling).
    """
    h, w, _ = frame.shape
    if target_short_side > min(h, w):
        return frame
    if h < w:
        new_h = target_short_side
        new_w = int(target_short_side * w / h)
    else:
        new_w = target_short_side
        new_h = int(target_short_side * h / w)
    return cv2.resize(frame, (new_w, new_h))
375
+
376
+
377
class ImageVideoDataset(Dataset):
    """Mixed dataset of images, videos, and (video, mask) removal tuples.

    Annotation entries (from a .csv or .json file) carry 'file_path', 'text'
    and an optional 'type' ('image' by default, 'video', or
    'video_mask_tuple'). Depending on the type, __getitem__ returns a dict
    with 'pixel_values', 'text', 'data_type' and, for tuples, additionally
    'input_condition', 'mask' and optionally 'depth_maps'. When
    enable_inpaint is set (non-bucket mode), random inpainting masks and the
    derived 'mask_pixel_values' / 'clip_pixel_values' / 'ref_pixel_values'
    are added.
    """

    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        enable_inpaint=False,
        trimask_zeroout_removal=False,
        use_quadmask=False,
        ablation_binary_mask=False,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))
        else:
            raise ValueError(f"Unsupported annotation file format: {ann_path}. Only .csv and .json files are supported.")

        self.data_root = data_root

        # It's used to balance num of images and videos.
        # NOTE(review): entries with type == 'video' are only included when
        # video_repeat > 0 — with the default video_repeat=0 they are dropped
        # entirely; confirm this is intended.
        self.dataset = []
        for data in dataset:
            if data.get('type', 'image') != 'video':
                self.dataset.append(data)
        if video_repeat > 0:
            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        # Mask-format flags for the 'video_mask_tuple' branch (see get_batch).
        self.trimask_zeroout_removal = trimask_zeroout_removal
        self.use_quadmask = use_quadmask
        self.ablation_binary_mask = ablation_binary_mask

        # Fractional window [start, end) of each video that sampling may use.
        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        if self.use_quadmask:
            print(f"[QUADMASK MODE] Using 4-value quadmask: [0, 63, 127, 255]")
            if self.ablation_binary_mask:
                print(f"[ABLATION BINARY MASK] Remapping quadmask to binary: [0,63]→0, [127,255]→127")
        else:
            print(f"[TRIMASK MODE] Using 3-value trimask: [0, 127, 255]")

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
        ])

        # Frames are first resized so their short side matches the larger of
        # the two target short sides, then cropped by the transforms above.
        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        """Load one sample (video / video+mask / video_mask_tuple / image).

        Raises ValueError on read timeouts or decode failures; __getitem__
        catches these and retries with a different random index.
        """
        data_info = self.dataset[idx % len(self.dataset)]

        if data_info.get('type', 'image') == 'video' and data_info.get('mask_path', None) is None:
            # --- Plain video sample -----------------------------------------
            video_id, text = data_info['file_path'], data_info['text']

            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                # Clamp the clip length to what the (cropped) video can supply
                # at the requested stride.
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError(f"No Frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    # Decode with a hard timeout so a corrupt file cannot hang
                    # the dataloader worker.
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    # (f, h, w, c) uint8 -> (f, c, h, w) float in [0, 1]
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader
                else:
                    pixel_values = pixel_values

                if not self.enable_bucket:
                    pixel_values = self.video_transforms(pixel_values)

                # Random use no text generation
                if random.random() < self.text_drop_ratio:
                    text = ''
            return {
                'pixel_values': pixel_values,
                'text': text,
                'data_type': 'video',
            }
        elif data_info.get('type', 'image') == 'video' and data_info.get('mask_path', None) is not None:  # video with known mask
            # --- Video with a companion "<name>_mask.mp4" -------------------
            # NOTE(review): unlike the branch above, data_root is NOT joined
            # here — file_path must already be an absolute/relative full path.
            video_path, text = data_info['file_path'], data_info['text']
            mask_video_path = video_path[:-4] + '_mask.mp4'
            with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError(f"No Frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    input_video = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            # Read the mask video at the exact same frame indices.
            with VideoReader_contextmanager(mask_video_path, num_threads=2) as video_reader:
                try:
                    sample_args = (video_reader, batch_index)
                    mask_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(mask_values)):
                        frame = mask_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    mask_video = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            # Normalize the mask to a single channel: (f, h, w, 1).
            if len(mask_video.shape) == 3:
                mask_video = mask_video[..., None]
            if mask_video.shape[-1] == 3:
                mask_video = mask_video[..., :1]
            if len(mask_video.shape) != 4:
                raise ValueError(f"mask_video shape is {mask_video.shape}.")

            text = data_info['text']
            if not self.enable_bucket:
                input_video = torch.from_numpy(input_video).permute(0, 3, 1, 2).contiguous() / 255.
                mask_video = torch.from_numpy(mask_video).permute(0, 3, 1, 2).contiguous() / 255.

                # Concatenate along channels so Resize/CenterCrop are applied
                # identically to both, then split again. (Normalize also runs
                # over the mask channel here — inherited behavior.)
                pixel_values = torch.cat([input_video, mask_video], dim=1)
                pixel_values = self.video_transforms(pixel_values)
                input_video = pixel_values[:, :3]
                mask_video = pixel_values[:, 3:]

            # Random use no text generation
            if random.random() < self.text_drop_ratio:
                text = ''

            return {
                'pixel_values': input_video,
                'mask': mask_video,
                'text': text,
                'data_type': 'video',
            }

        elif data_info.get('type', 'image') == 'video_mask_tuple':  # object effect removal
            # --- (input, removed-background, mask[, depth]) tuple -----------
            sample_dir = data_info['file_path']
            try:
                if os.path.exists(os.path.join(sample_dir, 'rgb_full.mp4')):
                    # mp4 layout
                    input_video_path = os.path.join(sample_dir, 'rgb_full.mp4')
                    target_video_path = os.path.join(sample_dir, 'rgb_removed.mp4')
                    mask_video_path = os.path.join(sample_dir, 'mask.mp4')
                    depth_video_path = os.path.join(sample_dir, 'depth_removed.mp4')

                    input_video = media.read_video(input_video_path)
                    target_video = media.read_video(target_video_path)
                    mask_video = media.read_video(mask_video_path)

                    # Load depth map if it exists
                    depth_video = None
                    if os.path.exists(depth_video_path):
                        depth_video = media.read_video(depth_video_path)

                else:
                    # per-frame PNG layout
                    input_video_path = os.path.join(sample_dir, 'input')
                    target_video_path = os.path.join(sample_dir, 'bg')
                    mask_video_path = os.path.join(sample_dir, 'trimask')

                    input_video = _read_video_from_dir(input_video_path)
                    target_video = _read_video_from_dir(target_video_path)
                    mask_video = _read_video_from_dir(mask_video_path)

                    # Initialize depth_video as None for this path
                    depth_video = None
            except Exception as e:
                print(f"Error loading video_mask_tuple from {sample_dir}: {e}")
                import traceback
                traceback.print_exc()
                raise

            # Invert polarity; it is flipped back again when fed to the model.
            mask_video = 255 - mask_video # will be flipped again in when feeding to model

            if len(mask_video.shape) == 3:
                mask_video = mask_video[..., None]
            if mask_video.shape[-1] == 3:
                mask_video = mask_video[..., :1]
            # Temporal sampling: same windowed strided selection as the video
            # branches, applied jointly to all streams.
            min_sample_n_frames = min(
                self.video_sample_n_frames,
                int(len(input_video) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
            )
            video_length = int(self.video_length_drop_end * len(input_video))
            clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
            start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
            input_video = input_video[batch_index]
            target_video = target_video[batch_index]
            mask_video = mask_video[batch_index]
            if depth_video is not None:
                depth_video = depth_video[batch_index]

            resized_inputs = []
            resized_targets = []
            resized_masks = []
            resized_depths = []
            for i in range(len(input_video)):
                resized_input = resize_frame(input_video[i], self.larger_side_of_image_and_video)
                resized_target = resize_frame(target_video[i], self.larger_side_of_image_and_video)
                resized_mask = resize_frame(mask_video[i], self.larger_side_of_image_and_video)

                # Apply mask quantization based on mode
                if self.ablation_binary_mask:
                    # Ablation binary mask mode: remap [0, 63, 127, 255] to [0, 127]
                    # Map 0 and 63 → 0
                    # Map 127 and 255 → 127
                    resized_mask = np.where(resized_mask <= 95, 0, resized_mask)
                    resized_mask = np.where(resized_mask > 95, 127, resized_mask)
                elif self.use_quadmask:
                    # Quadmask mode: preserve 4 values [0, 63, 127, 255]
                    # Quantize to nearest quadmask value for robustness
                    resized_mask = np.where(resized_mask <= 31, 0, resized_mask)
                    resized_mask = np.where(np.logical_and(resized_mask > 31, resized_mask <= 95), 63, resized_mask)
                    resized_mask = np.where(np.logical_and(resized_mask > 95, resized_mask <= 191), 127, resized_mask)
                    resized_mask = np.where(resized_mask > 191, 255, resized_mask)
                else:
                    # Trimask mode: 3 values [0, 127, 255]
                    resized_mask = np.where(np.logical_and(resized_mask > 63, resized_mask < 192), 127, resized_mask)
                    resized_mask = np.where(resized_mask >= 192, 255, resized_mask)
                    resized_mask = np.where(resized_mask <= 63, 0, resized_mask)

                resized_inputs.append(resized_input)
                resized_targets.append(resized_target)
                resized_masks.append(resized_mask)

                if depth_video is not None:
                    resized_depth = resize_frame(depth_video[i], self.larger_side_of_image_and_video)
                    resized_depths.append(resized_depth)

            input_video = np.array(resized_inputs)
            target_video = np.array(resized_targets)
            mask_video = np.array(resized_masks)
            if depth_video is not None:
                depth_video = np.array(resized_depths)

            # cv2.resize may have dropped the trailing singleton channel.
            if len(mask_video.shape) == 3:
                mask_video = mask_video[..., None]
            if mask_video.shape[-1] == 3:
                mask_video = mask_video[..., :1]
            if len(mask_video.shape) != 4:
                raise ValueError(f"mask_video shape is {mask_video.shape}.")

            text = data_info['text']
            # NOTE(review): the DEBUG prints below run on every sample and are
            # very noisy in dataloader workers — consider removing or gating.
            print(f"DEBUG DATASET: Converting to tensors (enable_bucket={self.enable_bucket})...")
            if not self.enable_bucket:
                print(f"DEBUG DATASET: Converting input_video to tensor...")
                input_video = torch.from_numpy(input_video).permute(0, 3, 1, 2).contiguous() / 255.
                print(f"DEBUG DATASET: Converting target_video to tensor...")
                target_video = torch.from_numpy(target_video).permute(0, 3, 1, 2).contiguous() / 255.
                print(f"DEBUG DATASET: Converting mask_video to tensor...")
                mask_video = torch.from_numpy(mask_video).permute(0, 3, 1, 2).contiguous() / 255.

                # Process depth video if available
                if depth_video is not None:
                    print(f"DEBUG DATASET: Processing depth_video...")
                    # IMPORTANT: Copy depth_video to ensure it's not memory-mapped
                    # Memory-mapped files can cause bus errors on GPU transfer
                    print(f"DEBUG DATASET: Copying depth_video to ensure not memory-mapped...")
                    depth_video = np.array(depth_video, copy=True)
                    print(f"DEBUG DATASET: depth_video copied, shape={depth_video.shape}")

                    # Ensure depth has correct shape
                    if len(depth_video.shape) == 3:
                        depth_video = depth_video[..., None]
                    if depth_video.shape[-1] == 3:
                        # Convert to grayscale if RGB
                        print(f"DEBUG DATASET: Converting depth to grayscale...")
                        depth_video = depth_video.mean(axis=-1, keepdims=True)
                    # Convert to tensor [F, 1, H, W] and normalize to [0, 1]
                    print(f"DEBUG DATASET: Converting depth to tensor...")
                    depth_video = torch.from_numpy(depth_video).permute(0, 3, 1, 2).contiguous().float() / 255.
                    # Ensure tensor is contiguous and owned
                    print(f"DEBUG DATASET: Cloning depth tensor...")
                    depth_video = depth_video.clone().contiguous()
                    print(f"DEBUG DATASET: depth_video final shape: {depth_video.shape}, is_contiguous: {depth_video.is_contiguous()}")

                # Apply transforms to each video separately (they expect 3 channels)
                print(f"DEBUG DATASET: Applying video transforms...")
                input_video = self.video_transforms(input_video)
                target_video = self.video_transforms(target_video)
                # Don't normalize mask since it's single channel
                print(f"DEBUG DATASET: Normalizing mask_video...")
                mask_video = mask_video * 2.0 - 1.0  # Scale to [-1, 1] like other channels
                print(f"DEBUG DATASET: All tensors ready (non-bucket mode)")

            else:
                # For bucket mode, keep as numpy until collate
                # Collate function expects [0, 255] range and will normalize
                print(f"DEBUG DATASET: Bucket mode - keeping as numpy in [0, 255] range...")
                print(f"DEBUG DATASET: All numpy arrays ready (bucket mode)")

            # Random use no text generation
            if random.random() < self.text_drop_ratio:
                text = ''

            if self.trimask_zeroout_removal:
                # NOTE(review): in non-bucket mode mask_video is a torch tensor
                # scaled to [-1, 1] at this point, so `> 200` is never true and
                # `.astype(torch dtype)` on the resulting ndarray would raise —
                # this flag looks only safe in bucket mode; confirm.
                input_video = input_video * np.where(mask_video > 200, 0, 1).astype(input_video.dtype)

            result = {
                'pixel_values': target_video,
                'input_condition': input_video,
                'mask': mask_video,
                'text': text,
                'data_type': 'video_mask_tuple',
            }

            # Add depth maps if available
            if depth_video is not None:
                result['depth_maps'] = depth_video

            return result

        else:
            # --- Plain image sample -----------------------------------------
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)
            if random.random() < self.text_drop_ratio:
                text = ''
            return {
                'pixel_values': image,
                'text': text,
                'data_type': 'image',
            }

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Fetch a sample, retrying with random indices of the SAME type on failure."""
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                # A retry index of a different type is rejected so batches
                # built by ImageVideoSampler stay homogeneous.
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                sample = self.get_batch(idx)
                sample["idx"] = idx

                if len(sample) > 0:
                    break
            except Exception as e:
                import traceback
                print(f"Error loading sample at index {idx}:")
                print(f"Data info: {self.dataset[idx % len(self.dataset)]}")
                print(f"Error: {e}")
                traceback.print_exc()
                idx = random.randint(0, self.length-1)

        if self.enable_inpaint and not self.enable_bucket:
            # Attach inpainting inputs: mask, masked pixels, CLIP frame, ref frame.
            if "mask" not in sample:
                mask = get_random_mask_multi(sample["pixel_values"].size())
                sample["mask"] = mask
            else:
                mask = sample["mask"]

            if "input_condition" in sample:
                mask_pixel_values = sample["input_condition"]
            else:
                mask_pixel_values = sample["pixel_values"]
            # Masked regions are filled with -1 (the normalized "black").
            mask_pixel_values = mask_pixel_values * (1 - mask) + torch.ones_like(mask_pixel_values) * -1 * mask

            sample["mask_pixel_values"] = mask_pixel_values

            # First frame, de-normalized back to [0, 255] HWC for CLIP.
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

            # Reference frame; blanked to -1 when the whole clip is masked.
            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
            if (mask == 1).all():
                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
            sample["ref_pixel_values"] = ref_pixel_values

        return sample
839
+
840
+
841
class ImageVideoControlDataset(Dataset):
    """Mixed image/video dataset where every sample is paired with a control
    signal (e.g. pose / depth video) read from ``control_file_path``.

    Annotations are loaded from a .csv or .json file. Each record must
    provide ``file_path``, ``text`` and ``control_file_path``, plus an
    optional ``type`` field ('image' by default, or 'video').
    """

    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        enable_inpaint=False,
    ):
        """Load the annotation file and build transforms.

        Args:
            ann_path: Path to a .csv or .json annotation file.
            data_root: Optional root directory prepended to all file paths.
            video_sample_size: Target (short-side) size for video frames;
                int or (h, w) pair.
            video_sample_stride: Temporal stride between sampled frames.
            video_sample_n_frames: Max number of frames sampled per clip.
            image_sample_size: Target size for still images; int or pair.
            video_repeat: How many times video entries are repeated relative
                to image entries.
                NOTE(review): with the default of 0, video records are
                dropped entirely — confirm callers always pass >= 1 when
                videos are expected.
            text_drop_ratio: Probability of replacing the caption with ''.
            enable_bucket: If True, raw uint8 numpy arrays are returned and
                normalization is deferred to the caller.
            video_length_drop_start: Fraction of the video to skip at the start.
            video_length_drop_end: Fraction of the video usable at the end.
            enable_inpaint: If True, __getitem__ additionally builds random
                inpainting masks and reference frames.
        """
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))
        else:
            raise ValueError(f"Unsupported annotation file format: {ann_path}. Only .csv and .json files are supported.")

        self.data_root = data_root

        # It's used to balance num of images and videos: images are kept
        # once, videos are appended `video_repeat` times.
        self.dataset = []
        for data in dataset:
            if data.get('type', 'image') != 'video':
                self.dataset.append(data)
        if video_repeat > 0:
            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        # No ToTensor here: video frames arrive as float tensors already.
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
        ])

        # Short side used when pre-resizing decoded frames so that neither
        # the image nor the video pipeline ever needs to upscale.
        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        """Decode one sample (content + control) for index `idx`.

        Returns:
            (pixel_values, control_pixel_values, text, data_type) where
            data_type is "video" or 'image'. With bucketing enabled the
            pixel arrays are raw uint8 numpy; otherwise normalized tensors.

        Raises:
            ValueError: on decode failure, timeout, or an empty clip.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        video_id, text = data_info['file_path'], data_info['text']

        if data_info.get('type', 'image')=='video':
            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                # Clamp the requested frame count to what the usable range
                # of the video can actually supply at this stride.
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError(f"No Frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                # batch_index is reused below so the control video is sampled
                # at exactly the same frame positions.
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    # func_timeout guards against decoder hangs on bad files.
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader
                else:
                    pixel_values = pixel_values

            if not self.enable_bucket:
                pixel_values = self.video_transforms(pixel_values)

            # Randomly drop the caption for classifier-free guidance training.
            if random.random() < self.text_drop_ratio:
                text = ''

            control_video_id = data_info['control_file_path']

            if self.data_root is None:
                control_video_id = control_video_id
            else:
                control_video_id = os.path.join(self.data_root, control_video_id)

            with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
                try:
                    # Same batch_index as the content video: frames stay aligned.
                    sample_args = (control_video_reader, batch_index)
                    control_pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(control_pixel_values)):
                        frame = control_pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    control_pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
                    control_pixel_values = control_pixel_values / 255.
                    del control_video_reader
                else:
                    control_pixel_values = control_pixel_values

            if not self.enable_bucket:
                control_pixel_values = self.video_transforms(control_pixel_values)
            return pixel_values, control_pixel_values, text, "video"
        else:
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                # (1, C, H, W): a single image is treated as a 1-frame video.
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)

            if random.random() < self.text_drop_ratio:
                text = ''

            control_image_id = data_info['control_file_path']

            if self.data_root is None:
                control_image_id = control_image_id
            else:
                control_image_id = os.path.join(self.data_root, control_image_id)

            control_image = Image.open(control_image_id).convert('RGB')
            if not self.enable_bucket:
                control_image = self.image_transforms(control_image).unsqueeze(0)
            else:
                control_image = np.expand_dims(np.array(control_image), 0)
            return image, control_image, text, 'image'

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return a sample dict; on decode failure, retry a random index of
        the same content type until one succeeds."""
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                # Re-sampled indices must keep the original content type so
                # that batches stay homogeneous (see ImageVideoSampler).
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["control_pixel_values"] = control_pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx

                if len(sample) > 0:
                    break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)

        if self.enable_inpaint and not self.enable_bucket:
            # Build inpainting supervision: masked regions are filled with -1
            # (the low end of the [-1, 1] normalized range).
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            # First frame, de-normalized back to [0, 255] HWC for a CLIP-style
            # image encoder.
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

            # Reference frame; blanked to -1 when the mask covers everything.
            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
            if (mask == 1).all():
                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
            sample["ref_pixel_values"] = ref_pixel_values

        return sample
videox_fun/data/dataset_image_video_warped.py ADDED
@@ -0,0 +1,1092 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import io
3
+ import json
4
+ import math
5
+ import os
6
+ import glob
7
+ import random
8
+ from threading import Thread
9
+ import mediapy as media
10
+ import time
11
+
12
+ import albumentations
13
+ import cv2
14
+ import gc
15
+ import numpy as np
16
+ import torch
17
+ import torchvision.transforms as transforms
18
+ from scipy.special import binom
19
+
20
+ from func_timeout import func_timeout, FunctionTimedOut
21
+ from decord import VideoReader
22
+ from PIL import Image
23
+ from torch.utils.data import BatchSampler, Sampler
24
+ from torch.utils.data.dataset import Dataset
25
+ from contextlib import contextmanager
26
+
27
+ VIDEO_READER_TIMEOUT = 20
28
+
29
def bernstein(n, k, t):
    """Evaluate the Bernstein basis polynomial B_{k,n}(t) = C(n,k) t^k (1-t)^(n-k).

    Replaces the original ``bernstein = lambda ...`` assignment (PEP 8 E731)
    with an equivalent named function; call signature and result are unchanged.

    Args:
        n: Polynomial degree (non-negative int).
        k: Basis index, 0 <= k <= n.
        t: Scalar or numpy array of parameter values.

    Returns:
        The basis value(s), same shape as ``t``.
    """
    return binom(n, k) * t**k * (1. - t)**(n - k)
30
+
31
+ # codes from https://stackoverflow.com/questions/50731785/create-random-shape-contour-using-matplotlib
32
def bezier(points, num=200):
    """Sample `num` points of the Bezier curve defined by control `points`.

    Args:
        points: (N, 2) array of control points.
        num: Number of curve samples.

    Returns:
        (num, 2) array of curve coordinates.
    """
    degree = len(points) - 1
    ts = np.linspace(0, 1, num=num)
    curve = np.zeros((num, 2))
    # Weighted sum of control points under the Bernstein basis.
    for k, pt in enumerate(points):
        curve += np.outer(bernstein(degree, k, ts), pt)
    return curve
39
+
40
class Segment():
    """One cubic Bezier segment between two endpoints with prescribed
    tangent angles; used by get_curve to stitch a smooth closed contour.

    Keyword args (via **kw): numpoints (samples per segment, default 100)
    and r (control-point distance as a fraction of the endpoint distance,
    default 0.3). Any other keyword (e.g. method="var" from
    get_bezier_curve) is silently ignored.
    """

    def __init__(self, p1, p2, angle1, angle2, **kw):
        self.p1 = p1
        self.p2 = p2
        self.angle1 = angle1
        self.angle2 = angle2
        self.numpoints = kw.get("numpoints", 100)
        r = kw.get("r", 0.3)
        # Scale the control-point radius by the endpoint distance.
        d = np.sqrt(np.sum((self.p2-self.p1)**2))
        self.r = r*d
        # Four cubic-Bezier control points: ends are fixed here, the two
        # interior points are filled in by calc_intermediate_points.
        self.p = np.zeros((4,2))
        self.p[0,:] = self.p1[:]
        self.p[3,:] = self.p2[:]
        self.calc_intermediate_points(self.r)

    def calc_intermediate_points(self,r):
        # NOTE(review): the `r` parameter is ignored; self.r (set in
        # __init__) is used instead. Behavior is unchanged because the only
        # call site passes self.r anyway.
        # Interior control points lie along the prescribed tangent angles:
        # forward from p1, backward (angle + pi) from p2.
        self.p[1,:] = self.p1 + np.array(
            [self.r*np.cos(self.angle1), self.r*np.sin(self.angle1)])
        self.p[2,:] = self.p2 + np.array(
            [self.r*np.cos(self.angle2+np.pi), self.r*np.sin(self.angle2+np.pi)])
        self.curve = bezier(self.p,self.numpoints)
61
+
62
+
63
def get_curve(points, **kw):
    """Build consecutive Bezier segments through `points` and concatenate them.

    Args:
        points: (N, 3) array of rows (x, y, tangent_angle).
        **kw: Forwarded to Segment (e.g. r, numpoints).

    Returns:
        (segments, curve): the Segment objects and the stacked sample points.
    """
    segments = [
        Segment(a[:2], b[:2], a[2], b[2], **kw)
        for a, b in zip(points[:-1], points[1:])
    ]
    curve = np.concatenate([seg.curve for seg in segments])
    return segments, curve
70
+
71
+
72
def ccw_sort(p):
    """Return the rows of `p` ordered by angle around their centroid.

    Args:
        p: (N, 2) array of points.

    Returns:
        (N, 2) array: the same points, angularly sorted.
    """
    centered = p - p.mean(axis=0)
    # Note the (x, y) argument order to arctan2 is kept from the original.
    angles = np.arctan2(centered[:, 0], centered[:, 1])
    order = np.argsort(angles)
    return p[order, :]
76
+
77
+
78
def get_bezier_curve(a, rad=0.2, edgy=0):
    """ given an array of points *a*, create a curve through
    those points.
    *rad* is a number between 0 and 1 to steer the distance of
    control points.
    *edgy* is a parameter which controls how "edgy" the curve is,
    edgy=0 is smoothest.

    Returns (x, y, a): the sampled curve coordinates and the angularly
    sorted points augmented with a per-point tangent-angle column."""
    # Blend weight between the two adjacent edge directions at each point;
    # edgy=0 gives p=0.5 (equal blend, smoothest).
    p = np.arctan(edgy)/np.pi+.5
    a = ccw_sort(a)
    # Close the loop by repeating the first point at the end.
    a = np.append(a, np.atleast_2d(a[0,:]), axis=0)
    d = np.diff(a, axis=0)
    ang = np.arctan2(d[:,1],d[:,0])
    # Map angles into [0, 2*pi).
    f = lambda ang : (ang>=0)*ang + (ang<0)*(ang+2*np.pi)
    ang = f(ang)
    ang1 = ang
    ang2 = np.roll(ang,1)
    # Tangent at each point blends the incoming and outgoing edge angles;
    # the pi correction handles wrap-around between the two.
    ang = p*ang1 + (1-p)*ang2 + (np.abs(ang2-ang1) > np.pi )*np.pi
    ang = np.append(ang, [ang[0]])
    # Attach the tangent angle as a third column: rows become (x, y, angle).
    a = np.append(a, np.atleast_2d(ang).T, axis=1)
    # NOTE(review): method="var" is swallowed by Segment's **kw and unused.
    s, c = get_curve(a, r=rad, method="var")
    x,y = c.T
    return x,y, a
100
+
101
+
102
def get_random_points(n=5, scale=0.8, mindst=None, rec=0):
    """ create n random points in the unit square, which are *mindst*
    apart, then scale them. Retries (recursively) up to 200 times."""
    min_dist = mindst or .7/n
    pts = np.random.rand(n, 2)
    deltas = np.diff(ccw_sort(pts), axis=0)
    # NOTE(review): this is sqrt((dx+dy)^2) == |dx+dy|, not the Euclidean
    # step length; preserved as-is to keep behavior identical.
    spacing = np.sqrt(np.sum(deltas, axis=1) ** 2)
    if np.all(spacing >= min_dist) or rec >= 200:
        return pts * scale
    return get_random_points(n=n, scale=scale, mindst=min_dist, rec=rec + 1)
112
+
113
+
114
def fill_mask(shape, x, y, fill_val=255):
    """Rasterize the closed polygon with vertices (x, y) into a uint8 mask.

    Args:
        shape: (f, c, h, w) tuple; only h and w are used.
        x, y: Integer pixel coordinates of the polygon vertices.
        fill_val: Value written inside the polygon (default 255).

    Returns:
        (h, w) uint8 array, fill_val inside the polygon, 0 elsewhere.
    """
    height, width = shape[2], shape[3]
    canvas = np.zeros((height, width), dtype=np.uint8)
    polygon = np.array([x, y], np.int32).T
    return cv2.fillPoly(canvas, [polygon], fill_val)
119
+
120
+
121
def random_shift(x, y, scale_range = [0.2, 0.7], trans_perturb_range=[-0.2, 0.2]):
    """Randomly scale, translate, and jitter normalized contour coordinates.

    Draws six uniform variates in a fixed order (two scales, two
    translations, two jitters) so seeded runs stay reproducible.

    Returns:
        (x_shifted, y_shifted): transformed copies of the inputs.
    """
    lo, hi = scale_range
    sx = np.random.uniform(lo, hi)
    sy = np.random.uniform(lo, hi)
    # Translation keeps the scaled shape inside the unit square (before jitter).
    tx = np.random.uniform(0., 1. - sx)
    ty = np.random.uniform(0., 1. - sy)
    jlo, jhi = trans_perturb_range
    x_shifted = x * sx + tx + np.random.uniform(jlo, jhi)
    y_shifted = y * sy + ty + np.random.uniform(jlo, jhi)
    return x_shifted, y_shifted
129
+
130
+
131
def get_random_shape_mask(
    shape, n_pts_range=[3, 10], rad_range=[0.0, 1.0], edgy_range=[0.0, 0.1], n_keyframes_range=[2, 25],
    random_drop_range=[0.0, 0.2],
):
    """Animated random blob mask: Bezier blobs at random keyframes, with the
    contour linearly interpolated between keyframes, plus a random run of
    frames zeroed out.

    Args:
        shape: (f, c, h, w) of the target video tensor.
        n_pts_range: [lo, hi) number of Bezier control points.
        rad_range / edgy_range: ranges for get_bezier_curve parameters.
        n_keyframes_range: [lo, hi) number of keyframes.
        random_drop_range: fraction range of frames to zero out.
    """
    f, _, h, w = shape

    n_pts = np.random.randint(n_pts_range[0], n_pts_range[1])
    n_keyframes = np.random.randint(n_keyframes_range[0], n_keyframes_range[1])
    # Place keyframes roughly evenly, forcing the last one onto frame f-1.
    keyframe_interval = f // (n_keyframes - 1)
    keyframe_indices = list(range(0, f, keyframe_interval))
    if len(keyframe_indices) == n_keyframes:
        keyframe_indices[-1] = f - 1
    else:
        keyframe_indices.append(f - 1)
    x_all_frames, y_all_frames = [], []
    for i, keyframe_index in enumerate(keyframe_indices):
        # A fresh random blob contour at each keyframe.
        rad = np.random.uniform(rad_range[0], rad_range[1])
        edgy = np.random.uniform(edgy_range[0], edgy_range[1])
        x_kf, y_kf, _ = get_bezier_curve(get_random_points(n=n_pts), rad=rad, edgy=edgy)
        x_kf, y_kf = random_shift(x_kf, y_kf)
        if i == 0:
            x_all_frames.append(x_kf[None])
            y_all_frames.append(y_kf[None])
        else:
            # Linearly interpolate contour vertices from the previous frame's
            # contour to this keyframe; drop index 0 to avoid duplicating it.
            x_interval = np.linspace(x_all_frames[-1][-1], x_kf, keyframe_index - keyframe_indices[i - 1] + 1)
            y_interval = np.linspace(y_all_frames[-1][-1], y_kf, keyframe_index - keyframe_indices[i - 1] + 1)
            x_all_frames.append(x_interval[1:])
            y_all_frames.append(y_interval[1:])
    x_all_frames = np.concatenate(x_all_frames, axis=0)
    y_all_frames = np.concatenate(y_all_frames, axis=0)

    # Rasterize each frame's contour into a binary mask.
    masks = []
    for x, y in zip(x_all_frames, y_all_frames):
        x = np.round(x * w).astype(np.int32)
        y = np.round(y * h).astype(np.int32)
        mask = fill_mask(shape, x, y)
        masks.append(mask)
    masks = np.stack(masks, axis=0).astype(float) / 255.

    # Zero out a random contiguous run of frames (the object "disappears").
    n_frames_random_drop = int(np.random.uniform(random_drop_range[0], random_drop_range[1]) * f)
    drop_index = np.random.randint(0, f - n_frames_random_drop)
    masks[drop_index:drop_index + n_frames_random_drop] = 0

    return masks # (f, h, w), <float>[0, 1]
175
+
176
+
177
def get_random_mask(shape, mask_type_probs=[0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.8]):
    """Sample a random inpainting mask of shape (f, 1, h, w), float in {0, 1}.

    Eleven mask families (index -> pattern below) are drawn with
    `mask_type_probs`; single-frame inputs only use families 0 and 1.

    Args:
        shape: (f, c, h, w) of the target tensor.
        mask_type_probs: Probability per mask family (defaults favor the
            animated bezier-shape family, index 10, at 0.8).

    Returns:
        torch.FloatTensor of shape (f, 1, h, w); 1 marks masked pixels.
    """
    f, c, h, w = shape

    if f != 1:
        mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], p=mask_type_probs)
    else:
        # Single image: only static rectangle (0) or full mask (1).
        mask_index = np.random.choice([0, 1], p = [0.2, 0.8])
    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)

    if mask_index == 0:
        # 0: one static rectangle, identical across all frames.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()
        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()

        start_x = max(center_x - block_size_x // 2, 0)
        end_x = min(center_x + block_size_x // 2, w)
        start_y = max(center_y - block_size_y // 2, 0)
        end_y = min(center_y + block_size_y // 2, h)
        mask[:, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 1:
        # 1: everything masked (unconditional generation).
        mask[:, :, :, :] = 1
    elif mask_index == 2:
        # 2: keep the first 1-4 frames, mask the rest (video prediction).
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:, :, :, :] = 1
    elif mask_index == 3:
        # 3: keep a few frames at both ends, mask the middle (interpolation).
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
    elif mask_index == 4:
        # 4: one rectangle over a random temporal window.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()
        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()

        start_x = max(center_x - block_size_x // 2, 0)
        end_x = min(center_x + block_size_x // 2, w)
        start_y = max(center_y - block_size_y // 2, 0)
        end_y = min(center_y + block_size_y // 2, h)

        mask_frame_before = np.random.randint(0, f // 2)
        mask_frame_after = np.random.randint(f // 2, f)
        mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 5:
        # 5: independent per-pixel Bernoulli noise.
        mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
    elif mask_index == 6:
        # 6: small random blocks on a random subset of frames.
        num_frames_to_mask = random.randint(1, max(f // 2, 1))
        frames_to_mask = random.sample(range(f), num_frames_to_mask)

        for i in frames_to_mask:
            block_height = random.randint(1, h // 4)
            block_width = random.randint(1, w // 4)
            top_left_y = random.randint(0, h - block_height)
            top_left_x = random.randint(0, w - block_width)
            mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
    elif mask_index == 7:
        # 7: static ellipse. NOTE(review): filled with an O(h*w) Python
        # loop; rarely drawn (p=0.02 by default) so left as-is.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item()
        b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()

        for i in range(h):
            for j in range(w):
                if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
                    mask[:, :, i, j] = 1
    elif mask_index == 8:
        # 8: static circle (same O(h*w) loop caveat as the ellipse).
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
        for i in range(h):
            for j in range(w):
                if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
                    mask[:, :, i, j] = 1
    elif mask_index == 9:
        # 9: each frame fully masked or fully kept with probability 0.5.
        for idx in range(f):
            if np.random.rand() > 0.5:
                mask[idx, :, :, :] = 1
    else:
        # 10 (default-heavy family): union of 1-3 animated bezier blobs.
        num_objs = np.random.randint(1, 4)
        mask_npy = get_random_shape_mask(shape)
        for i in range(num_objs - 1):
            mask_npy += get_random_shape_mask(shape).clip(0, 1)

        mask = torch.from_numpy(mask_npy).unsqueeze(1)

    return mask.float()
262
+
263
+
264
def get_random_mask_multi(shape, mask_type_probs, range_num_masks=[1, 7]):
    """Union of several independently sampled random masks.

    Args:
        shape: (f, c, h, w) of the target tensor.
        mask_type_probs: Per-family probabilities forwarded to get_random_mask.
        range_num_masks: [lo, hi) range for the number of masks to union
            (numpy randint, upper bound exclusive).

    Returns:
        Combined float mask clipped to [0, 1] (None only if the range is
        degenerate and zero masks are drawn).
    """
    count = np.random.randint(range_num_masks[0], range_num_masks[1])
    combined = None
    for _ in range(count):
        new_mask = get_random_mask(shape, mask_type_probs)
        combined = new_mask if combined is None else (combined + new_mask).clip(0, 1)
    return combined
274
+
275
+
276
class ImageVideoSampler(BatchSampler):
    """A sampler wrapper for grouping images with similar aspect ratio into a same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
            NOTE(review): drop_last is stored but never consulted in
            __iter__ — trailing partial buckets are always discarded when
            the base sampler is exhausted. Confirm this is intended.
        aspect_ratios (dict): The predefined aspect ratios.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # buckets for each aspect ratio; one bucket per content type so that
        # each emitted batch is homogeneous ('image' / 'video' /
        # 'video_mask_tuple').
        self.bucket = {'image':[], 'video':[], 'video_mask_tuple':[]}

    def __iter__(self):
        """Route each sampled index into its content-type bucket and yield a
        batch whenever one bucket reaches batch_size."""
        for idx in self.sampler:
            content_type = self.dataset.dataset[idx].get('type', 'image')
            self.bucket[content_type].append(idx)

            # yield a batch of indices in the same aspect ratio group
            if len(self.bucket['video']) == self.batch_size:
                bucket = self.bucket['video']
                yield bucket[:]
                del bucket[:]
            elif len(self.bucket['video_mask_tuple']) == self.batch_size:
                bucket = self.bucket['video_mask_tuple']
                yield bucket[:]
                del bucket[:]
            elif len(self.bucket['image']) == self.batch_size:
                bucket = self.bucket['image']
                yield bucket[:]
                del bucket[:]
326
+
327
+
328
@contextmanager
def VideoReader_contextmanager(*args, **kwargs):
    """Yield a decord VideoReader and force its release on exit.

    Drops the reader reference and runs a GC pass in the ``finally`` block
    so decoder resources are reclaimed promptly even on error.
    """
    reader = VideoReader(*args, **kwargs)
    try:
        yield reader
    finally:
        del reader
        gc.collect()
336
+
337
+
338
def get_video_reader_batch(video_reader, batch_index):
    """Decode the frames at `batch_index` from `video_reader` and return
    them as a numpy array (via decord's ``asnumpy``)."""
    batch = video_reader.get_batch(batch_index)
    return batch.asnumpy()
341
+
342
+
343
+ def _read_video_from_dir(video_dir):
344
+ frames = []
345
+ frame_paths = sorted(list(glob.glob(os.path.join(video_dir, '*.png'))))
346
+
347
+ if not frame_paths:
348
+ raise ValueError(f"No PNG files found in directory: {video_dir}")
349
+
350
+ for frame_path in frame_paths:
351
+ frame = media.read_image(frame_path)
352
+ frames.append(frame)
353
+
354
+ if not frames:
355
+ raise ValueError(f"Failed to read any frames from directory: {video_dir}")
356
+
357
+ return np.stack(frames, axis=0)
358
+
359
+
360
def resize_frame(frame, target_short_side):
    """Downscale `frame` so its shorter side equals `target_short_side`,
    preserving aspect ratio. Frames whose short side is already at or below
    the target are returned unchanged (this function never upscales).

    Args:
        frame: (h, w, c) image array.
        target_short_side: Desired length of the shorter side in pixels.
    """
    h, w, _ = frame.shape
    if h < w:
        if target_short_side > h:
            return frame
        new_h, new_w = target_short_side, int(target_short_side * w / h)
    else:
        if target_short_side > w:
            return frame
        new_h, new_w = int(target_short_side * h / w), target_short_side
    return cv2.resize(frame, (new_w, new_h))
375
+
376
+
377
+ class ImageVideoDataset(Dataset):
378
+ def __init__(
379
+ self,
380
+ ann_path, data_root=None,
381
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
382
+ image_sample_size=512,
383
+ video_repeat=0,
384
+ text_drop_ratio=0.1,
385
+ enable_bucket=False,
386
+ video_length_drop_start=0.0,
387
+ video_length_drop_end=1.0,
388
+ enable_inpaint=False,
389
+ trimask_zeroout_removal=False,
390
+ use_quadmask=False,
391
+ ablation_binary_mask=False,
392
+ ):
393
+ # Loading annotations from files
394
+ print(f"loading annotations from {ann_path} ...")
395
+ if ann_path.endswith('.csv'):
396
+ with open(ann_path, 'r') as csvfile:
397
+ dataset = list(csv.DictReader(csvfile))
398
+ elif ann_path.endswith('.json'):
399
+ dataset = json.load(open(ann_path))
400
+ else:
401
+ raise ValueError(f"Unsupported annotation file format: {ann_path}. Only .csv and .json files are supported.")
402
+
403
+ self.data_root = data_root
404
+
405
+ # It's used to balance num of images and videos.
406
+ self.dataset = []
407
+ for data in dataset:
408
+ if data.get('type', 'image') != 'video':
409
+ self.dataset.append(data)
410
+ if video_repeat > 0:
411
+ for _ in range(video_repeat):
412
+ for data in dataset:
413
+ if data.get('type', 'image') == 'video':
414
+ self.dataset.append(data)
415
+ del dataset
416
+
417
+ self.length = len(self.dataset)
418
+ print(f"data scale: {self.length}")
419
+ # TODO: enable bucket training
420
+ self.enable_bucket = enable_bucket
421
+ self.text_drop_ratio = text_drop_ratio
422
+ self.enable_inpaint = enable_inpaint
423
+ self.trimask_zeroout_removal = trimask_zeroout_removal
424
+ self.use_quadmask = use_quadmask
425
+ self.ablation_binary_mask = ablation_binary_mask
426
+
427
+ self.video_length_drop_start = video_length_drop_start
428
+ self.video_length_drop_end = video_length_drop_end
429
+
430
+ if self.use_quadmask:
431
+ print(f"[QUADMASK MODE] Using 4-value quadmask: [0, 63, 127, 255]")
432
+ if self.ablation_binary_mask:
433
+ print(f"[ABLATION BINARY MASK] Remapping quadmask to binary: [0,63]→0, [127,255]→127")
434
+ else:
435
+ print(f"[TRIMASK MODE] Using 3-value trimask: [0, 127, 255]")
436
+
437
+ # Video params
438
+ self.video_sample_stride = video_sample_stride
439
+ self.video_sample_n_frames = video_sample_n_frames
440
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
441
+ self.video_transforms = transforms.Compose(
442
+ [
443
+ transforms.Resize(min(self.video_sample_size)),
444
+ transforms.CenterCrop(self.video_sample_size),
445
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
446
+ ]
447
+ )
448
+
449
+ # Image params
450
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
451
+ self.image_transforms = transforms.Compose([
452
+ transforms.Resize(min(self.image_sample_size)),
453
+ transforms.CenterCrop(self.image_sample_size),
454
+ transforms.ToTensor(),
455
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
456
+ ])
457
+
458
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
459
+
460
+ def get_batch(self, idx):
461
+ data_info = self.dataset[idx % len(self.dataset)]
462
+
463
+ if data_info.get('type', 'image') == 'video' and data_info.get('mask_path', None) is None:
464
+ video_id, text = data_info['file_path'], data_info['text']
465
+
466
+ if self.data_root is None:
467
+ video_dir = video_id
468
+ else:
469
+ video_dir = os.path.join(self.data_root, video_id)
470
+
471
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
472
+ min_sample_n_frames = min(
473
+ self.video_sample_n_frames,
474
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
475
+ )
476
+ if min_sample_n_frames == 0:
477
+ raise ValueError(f"No Frames in video.")
478
+
479
+ video_length = int(self.video_length_drop_end * len(video_reader))
480
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
481
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
482
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
483
+
484
+ try:
485
+ sample_args = (video_reader, batch_index)
486
+ pixel_values = func_timeout(
487
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
488
+ )
489
+ resized_frames = []
490
+ for i in range(len(pixel_values)):
491
+ frame = pixel_values[i]
492
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
493
+ resized_frames.append(resized_frame)
494
+ pixel_values = np.array(resized_frames)
495
+ except FunctionTimedOut:
496
+ raise ValueError(f"Read {idx} timeout.")
497
+ except Exception as e:
498
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
499
+
500
+ if not self.enable_bucket:
501
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
502
+ pixel_values = pixel_values / 255.
503
+ del video_reader
504
+ else:
505
+ pixel_values = pixel_values
506
+
507
+ if not self.enable_bucket:
508
+ pixel_values = self.video_transforms(pixel_values)
509
+
510
+ # Random use no text generation
511
+ if random.random() < self.text_drop_ratio:
512
+ text = ''
513
+ return {
514
+ 'pixel_values': pixel_values,
515
+ 'text': text,
516
+ 'data_type': 'video',
517
+ }
518
+ elif data_info.get('type', 'image') == 'video' and data_info.get('mask_path', None) is not None: # video with known mask
519
+ video_path, text = data_info['file_path'], data_info['text']
520
+ mask_video_path = video_path[:-4] + '_mask.mp4'
521
+ with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
522
+ min_sample_n_frames = min(
523
+ self.video_sample_n_frames,
524
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
525
+ )
526
+ if min_sample_n_frames == 0:
527
+ raise ValueError(f"No Frames in video.")
528
+
529
+ video_length = int(self.video_length_drop_end * len(video_reader))
530
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
531
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
532
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
533
+
534
+ try:
535
+ sample_args = (video_reader, batch_index)
536
+ pixel_values = func_timeout(
537
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
538
+ )
539
+ resized_frames = []
540
+ for i in range(len(pixel_values)):
541
+ frame = pixel_values[i]
542
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
543
+ resized_frames.append(resized_frame)
544
+ input_video = np.array(resized_frames)
545
+ except FunctionTimedOut:
546
+ raise ValueError(f"Read {idx} timeout.")
547
+ except Exception as e:
548
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
549
+
550
+ with VideoReader_contextmanager(mask_video_path, num_threads=2) as video_reader:
551
+ try:
552
+ sample_args = (video_reader, batch_index)
553
+ mask_values = func_timeout(
554
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
555
+ )
556
+ resized_frames = []
557
+ for i in range(len(mask_values)):
558
+ frame = mask_values[i]
559
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
560
+ resized_frames.append(resized_frame)
561
+ mask_video = np.array(resized_frames)
562
+ except FunctionTimedOut:
563
+ raise ValueError(f"Read {idx} timeout.")
564
+ except Exception as e:
565
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
566
+
567
+ if len(mask_video.shape) == 3:
568
+ mask_video = mask_video[..., None]
569
+ if mask_video.shape[-1] == 3:
570
+ mask_video = mask_video[..., :1]
571
+ if len(mask_video.shape) != 4:
572
+ raise ValueError(f"mask_video shape is {mask_video.shape}.")
573
+
574
+ text = data_info['text']
575
+ if not self.enable_bucket:
576
+ input_video = torch.from_numpy(input_video).permute(0, 3, 1, 2).contiguous() / 255.
577
+ mask_video = torch.from_numpy(mask_video).permute(0, 3, 1, 2).contiguous() / 255.
578
+
579
+ pixel_values = torch.cat([input_video, mask_video], dim=1)
580
+ pixel_values = self.video_transforms(pixel_values)
581
+ input_video = pixel_values[:, :3]
582
+ mask_video = pixel_values[:, 3:]
583
+
584
+ # Random use no text generation
585
+ if random.random() < self.text_drop_ratio:
586
+ text = ''
587
+
588
+ return {
589
+ 'pixel_values': input_video,
590
+ 'mask': mask_video,
591
+ 'text': text,
592
+ 'data_type': 'video',
593
+ }
594
+
595
+ elif data_info.get('type', 'image') == 'video_mask_tuple': # object effect removal
596
+ sample_dir = data_info['file_path'] if self.data_root is None else os.path.join(self.data_root, data_info['file_path'])
597
+ try:
598
+ if os.path.exists(os.path.join(sample_dir, 'rgb_full.mp4')):
599
+ input_video_path = os.path.join(sample_dir, 'rgb_full.mp4')
600
+ target_video_path = os.path.join(sample_dir, 'rgb_removed.mp4')
601
+ mask_video_path = os.path.join(sample_dir, 'mask.mp4')
602
+ depth_video_path = os.path.join(sample_dir, 'depth_removed.mp4')
603
+
604
+ input_video = media.read_video(input_video_path)
605
+ target_video = media.read_video(target_video_path)
606
+ mask_video = media.read_video(mask_video_path)
607
+
608
+ # Load depth map if it exists
609
+ depth_video = None
610
+ if os.path.exists(depth_video_path):
611
+ depth_video = media.read_video(depth_video_path)
612
+
613
+ else:
614
+ input_video_path = os.path.join(sample_dir, 'input')
615
+ target_video_path = os.path.join(sample_dir, 'bg')
616
+ mask_video_path = os.path.join(sample_dir, 'trimask')
617
+
618
+ input_video = _read_video_from_dir(input_video_path)
619
+ target_video = _read_video_from_dir(target_video_path)
620
+ mask_video = _read_video_from_dir(mask_video_path)
621
+
622
+ # Initialize depth_video as None for this path
623
+ depth_video = None
624
+ except Exception as e:
625
+ print(f"Error loading video_mask_tuple from {sample_dir}: {e}")
626
+ import traceback
627
+ traceback.print_exc()
628
+ raise
629
+
630
+ mask_video = 255 - mask_video # will be flipped again in when feeding to model
631
+
632
+ if len(mask_video.shape) == 3:
633
+ mask_video = mask_video[..., None]
634
+ if mask_video.shape[-1] == 3:
635
+ mask_video = mask_video[..., :1]
636
+ min_sample_n_frames = min(
637
+ self.video_sample_n_frames,
638
+ int(len(input_video) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
639
+ )
640
+ video_length = int(self.video_length_drop_end * len(input_video))
641
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
642
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
643
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
644
+ input_video = input_video[batch_index]
645
+ target_video = target_video[batch_index]
646
+ mask_video = mask_video[batch_index]
647
+ if depth_video is not None:
648
+ depth_video = depth_video[batch_index]
649
+
650
+ resized_inputs = []
651
+ resized_targets = []
652
+ resized_masks = []
653
+ resized_depths = []
654
+ for i in range(len(input_video)):
655
+ resized_input = resize_frame(input_video[i], self.larger_side_of_image_and_video)
656
+ resized_target = resize_frame(target_video[i], self.larger_side_of_image_and_video)
657
+ resized_mask = resize_frame(mask_video[i], self.larger_side_of_image_and_video)
658
+
659
+ # Apply mask quantization based on mode
660
+ if self.ablation_binary_mask:
661
+ # Ablation binary mask mode: remap [0, 63, 127, 255] to [0, 127]
662
+ # Map 0 and 63 → 0
663
+ # Map 127 and 255 → 127
664
+ resized_mask = np.where(resized_mask <= 95, 0, resized_mask)
665
+ resized_mask = np.where(resized_mask > 95, 127, resized_mask)
666
+ elif self.use_quadmask:
667
+ # Quadmask mode: preserve 4 values [0, 63, 127, 255]
668
+ # Quantize to nearest quadmask value for robustness
669
+ resized_mask = np.where(resized_mask <= 31, 0, resized_mask)
670
+ resized_mask = np.where(np.logical_and(resized_mask > 31, resized_mask <= 95), 63, resized_mask)
671
+ resized_mask = np.where(np.logical_and(resized_mask > 95, resized_mask <= 191), 127, resized_mask)
672
+ resized_mask = np.where(resized_mask > 191, 255, resized_mask)
673
+ else:
674
+ # Trimask mode: 3 values [0, 127, 255]
675
+ resized_mask = np.where(np.logical_and(resized_mask > 63, resized_mask < 192), 127, resized_mask)
676
+ resized_mask = np.where(resized_mask >= 192, 255, resized_mask)
677
+ resized_mask = np.where(resized_mask <= 63, 0, resized_mask)
678
+
679
+ resized_inputs.append(resized_input)
680
+ resized_targets.append(resized_target)
681
+ resized_masks.append(resized_mask)
682
+
683
+ if depth_video is not None:
684
+ resized_depth = resize_frame(depth_video[i], self.larger_side_of_image_and_video)
685
+ resized_depths.append(resized_depth)
686
+
687
+ input_video = np.array(resized_inputs)
688
+ target_video = np.array(resized_targets)
689
+ mask_video = np.array(resized_masks)
690
+ if depth_video is not None:
691
+ depth_video = np.array(resized_depths)
692
+
693
+ if len(mask_video.shape) == 3:
694
+ mask_video = mask_video[..., None]
695
+ if mask_video.shape[-1] == 3:
696
+ mask_video = mask_video[..., :1]
697
+ if len(mask_video.shape) != 4:
698
+ raise ValueError(f"mask_video shape is {mask_video.shape}.")
699
+
700
+ text = data_info['text']
701
+ print(f"DEBUG DATASET: Converting to tensors (enable_bucket={self.enable_bucket})...")
702
+ if not self.enable_bucket:
703
+ print(f"DEBUG DATASET: Converting input_video to tensor...")
704
+ input_video = torch.from_numpy(input_video).permute(0, 3, 1, 2).contiguous() / 255.
705
+ print(f"DEBUG DATASET: Converting target_video to tensor...")
706
+ target_video = torch.from_numpy(target_video).permute(0, 3, 1, 2).contiguous() / 255.
707
+ print(f"DEBUG DATASET: Converting mask_video to tensor...")
708
+ mask_video = torch.from_numpy(mask_video).permute(0, 3, 1, 2).contiguous() / 255.
709
+
710
+ # Process depth video if available
711
+ if depth_video is not None:
712
+ print(f"DEBUG DATASET: Processing depth_video...")
713
+ # IMPORTANT: Copy depth_video to ensure it's not memory-mapped
714
+ # Memory-mapped files can cause bus errors on GPU transfer
715
+ print(f"DEBUG DATASET: Copying depth_video to ensure not memory-mapped...")
716
+ depth_video = np.array(depth_video, copy=True)
717
+ print(f"DEBUG DATASET: depth_video copied, shape={depth_video.shape}")
718
+
719
+ # Ensure depth has correct shape
720
+ if len(depth_video.shape) == 3:
721
+ depth_video = depth_video[..., None]
722
+ if depth_video.shape[-1] == 3:
723
+ # Convert to grayscale if RGB
724
+ print(f"DEBUG DATASET: Converting depth to grayscale...")
725
+ depth_video = depth_video.mean(axis=-1, keepdims=True)
726
+ # Convert to tensor [F, 1, H, W] and normalize to [0, 1]
727
+ print(f"DEBUG DATASET: Converting depth to tensor...")
728
+ depth_video = torch.from_numpy(depth_video).permute(0, 3, 1, 2).contiguous().float() / 255.
729
+ # Ensure tensor is contiguous and owned
730
+ print(f"DEBUG DATASET: Cloning depth tensor...")
731
+ depth_video = depth_video.clone().contiguous()
732
+ print(f"DEBUG DATASET: depth_video final shape: {depth_video.shape}, is_contiguous: {depth_video.is_contiguous()}")
733
+
734
+ # Apply transforms to each video separately (they expect 3 channels)
735
+ print(f"DEBUG DATASET: Applying video transforms...")
736
+ input_video = self.video_transforms(input_video)
737
+ target_video = self.video_transforms(target_video)
738
+ # Don't normalize mask since it's single channel
739
+ print(f"DEBUG DATASET: Normalizing mask_video...")
740
+ mask_video = mask_video * 2.0 - 1.0 # Scale to [-1, 1] like other channels
741
+ print(f"DEBUG DATASET: All tensors ready (non-bucket mode)")
742
+
743
+ else:
744
+ # For bucket mode, keep as numpy until collate
745
+ # Collate function expects [0, 255] range and will normalize
746
+ print(f"DEBUG DATASET: Bucket mode - keeping as numpy in [0, 255] range...")
747
+ print(f"DEBUG DATASET: All numpy arrays ready (bucket mode)")
748
+
749
+ # Load warped noise - REQUIRED if specified in dataset
750
+ warped_noise = None
751
+ if 'warped_noise_path' in data_info:
752
+ warped_noise_dir = data_info['warped_noise_path'] if self.data_root is None else os.path.join(self.data_root, data_info['warped_noise_path'])
753
+ noise_path = os.path.join(warped_noise_dir, 'noises.npy')
754
+
755
+ if not os.path.exists(noise_path):
756
+ raise FileNotFoundError(
757
+ f"Warped noise path specified in dataset but file not found: {noise_path}\n"
758
+ f"Make sure you've generated warped noise for all videos in the dataset."
759
+ )
760
+
761
+ try:
762
+ warped_noise = np.load(noise_path) # Shape: (T, C, H, W) in float16
763
+ warped_noise = torch.from_numpy(warped_noise).float() # Convert to torch tensor
764
+ except Exception as e:
765
+ raise RuntimeError(
766
+ f"Failed to load warped noise from {noise_path}: {e}\n"
767
+ f"The noise file may be corrupted. Try regenerating it."
768
+ )
769
+
770
+ # Random use no text generation
771
+ if random.random() < self.text_drop_ratio:
772
+ text = ''
773
+
774
+ if self.trimask_zeroout_removal:
775
+ input_video = input_video * np.where(mask_video > 200, 0, 1).astype(input_video.dtype)
776
+
777
+ result = {
778
+ 'pixel_values': target_video,
779
+ 'input_condition': input_video,
780
+ 'mask': mask_video,
781
+ 'text': text,
782
+ 'data_type': 'video_mask_tuple',
783
+ }
784
+
785
+ # Add depth maps if available
786
+ if depth_video is not None:
787
+ result['depth_maps'] = depth_video
788
+
789
+ # Add warped noise to batch if available
790
+ if warped_noise is not None:
791
+ result['warped_noise'] = warped_noise
792
+
793
+ return result
794
+
795
+ else:
796
+ image_path, text = data_info['file_path'], data_info['text']
797
+ if self.data_root is not None:
798
+ image_path = os.path.join(self.data_root, image_path)
799
+ image = Image.open(image_path).convert('RGB')
800
+ if not self.enable_bucket:
801
+ image = self.image_transforms(image).unsqueeze(0)
802
+ else:
803
+ image = np.expand_dims(np.array(image), 0)
804
+ if random.random() < self.text_drop_ratio:
805
+ text = ''
806
+ return {
807
+ 'pixel_values': image,
808
+ 'text': text,
809
+ 'data_type': 'image',
810
+ }
811
+
812
    def __len__(self):
        """Return the nominal dataset length (number of annotation entries)."""
        return self.length
814
+
815
    def __getitem__(self, idx):
        """Fetch one training sample, retrying with a random index on failure.

        Retries keep the originally drawn sample's data type: if the random
        replacement index points at a different 'type', it is rejected and a
        new random index is drawn. When `enable_inpaint` is on (and bucket
        mode is off), inpainting auxiliaries (`mask`, `mask_pixel_values`,
        `clip_pixel_values`, `ref_pixel_values`) are attached.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                # Keep the retry within the same modality as the first draw.
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                sample = self.get_batch(idx)
                sample["idx"] = idx

                if len(sample) > 0:
                    break
            except Exception as e:
                import traceback
                print(f"Error loading sample at index {idx}:")
                print(f"Data info: {self.dataset[idx % len(self.dataset)]}")
                print(f"Error: {e}")
                traceback.print_exc()
                # Fall back to a random sample rather than crashing the loader.
                idx = random.randint(0, self.length-1)

        if self.enable_inpaint and not self.enable_bucket:
            # Use the mask produced by get_batch when present; otherwise
            # synthesize a random one matching the pixel tensor's shape.
            if "mask" not in sample:
                mask = get_random_mask_multi(sample["pixel_values"].size())
                sample["mask"] = mask
            else:
                mask = sample["mask"]

            # Masked input: original pixels where mask==0, -1 where mask==1.
            if "input_condition" in sample:
                mask_pixel_values = sample["input_condition"]
            else:
                mask_pixel_values = sample["pixel_values"]
            mask_pixel_values = mask_pixel_values * (1 - mask) + torch.ones_like(mask_pixel_values) * -1 * mask

            sample["mask_pixel_values"] = mask_pixel_values

            # First frame in HWC, rescaled from [-1, 1] back to [0, 255]
            # (presumably for a CLIP-style image encoder — confirm with caller).
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

            # Reference frame; blanked to -1 when the mask covers everything.
            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
            if (mask == 1).all():
                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
            sample["ref_pixel_values"] = ref_pixel_values

        return sample
864
+
865
+
866
class ImageVideoControlDataset(Dataset):
    """Mixed image/video dataset that pairs each sample with a control signal.

    Annotations come from a .csv or .json file; each entry provides
    'file_path', 'text', 'control_file_path' and an optional 'type'
    ('image' by default, or 'video'). Videos are sampled as fixed-length
    frame clips; the control video is read at the exact same frame indices.
    """

    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        enable_inpaint=False,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))
        else:
            raise ValueError(f"Unsupported annotation file format: {ann_path}. Only .csv and .json files are supported.")

        self.data_root = data_root

        # It's used to balance num of images and videos.
        # NOTE(review): with video_repeat == 0 (the default) no 'video'
        # entries are kept at all — video entries only enter via the repeat
        # loop below. Confirm this is intentional for callers.
        self.dataset = []
        for data in dataset:
            if data.get('type', 'image') != 'video':
                self.dataset.append(data)
        if video_repeat > 0:
            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint

        # Fractions of the video (from each end) excluded from clip sampling.
        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
        ])

        # Pre-resize target used by resize_frame before the final crop.
        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        """Load the sample plus its control signal for annotation entry `idx`.

        Returns:
            (pixel_values, control_pixel_values, text, data_type) where
            data_type is "video" or 'image'. Tensors in [-1, 1] when bucket
            mode is off; raw uint8 numpy arrays when it is on.

        Raises:
            ValueError: empty video, decode timeout, or decode failure.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        video_id, text = data_info['file_path'], data_info['text']

        if data_info.get('type', 'image')=='video':
            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                # Clamp the clip length to what the (cropped) video can supply.
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError(f"No Frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    # func_timeout guards against decord hanging on bad files.
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader
                else:
                    pixel_values = pixel_values

                if not self.enable_bucket:
                    pixel_values = self.video_transforms(pixel_values)

                # Random use no text generation (classifier-free guidance drop)
                if random.random() < self.text_drop_ratio:
                    text = ''

            control_video_id = data_info['control_file_path']

            if self.data_root is None:
                control_video_id = control_video_id
            else:
                control_video_id = os.path.join(self.data_root, control_video_id)

            # Read the control video at the same batch_index so that frames
            # stay aligned with the main clip.
            with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
                try:
                    sample_args = (control_video_reader, batch_index)
                    control_pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(control_pixel_values)):
                        frame = control_pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    control_pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
                    control_pixel_values = control_pixel_values / 255.
                    del control_video_reader
                else:
                    control_pixel_values = control_pixel_values

                if not self.enable_bucket:
                    control_pixel_values = self.video_transforms(control_pixel_values)
            return pixel_values, control_pixel_values, text, "video"
        else:
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)

            if random.random() < self.text_drop_ratio:
                text = ''

            control_image_id = data_info['control_file_path']

            if self.data_root is None:
                control_image_id = control_image_id
            else:
                control_image_id = os.path.join(self.data_root, control_image_id)

            control_image = Image.open(control_image_id).convert('RGB')
            if not self.enable_bucket:
                control_image = self.image_transforms(control_image).unsqueeze(0)
            else:
                control_image = np.expand_dims(np.array(control_image), 0)
            return image, control_image, text, 'image'

    def __len__(self):
        """Return the number of (repeated) annotation entries."""
        return self.length

    def __getitem__(self, idx):
        """Fetch one sample dict, retrying with a random index on failure."""
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                # Retries must keep the originally drawn modality.
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["control_pixel_values"] = control_pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx

                if len(sample) > 0:
                    break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)

        if self.enable_inpaint and not self.enable_bucket:
            # Build inpainting auxiliaries from a random mask.
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            # First frame in HWC, rescaled from [-1, 1] to [0, 255].
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

            # Reference frame; blanked to -1 when the mask covers everything.
            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
            if (mask == 1).all():
                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
            sample["ref_pixel_values"] = ref_pixel_values

        return sample
videox_fun/data/dataset_video.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ from contextlib import contextmanager
9
+ from threading import Thread
10
+
11
+ import albumentations
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+ import torchvision.transforms as transforms
16
+ from decord import VideoReader
17
+ from einops import rearrange
18
+ from func_timeout import FunctionTimedOut, func_timeout
19
+ from PIL import Image
20
+ from torch.utils.data import BatchSampler, Sampler
21
+ from torch.utils.data.dataset import Dataset
22
+
23
# Seconds allowed for a single decord batch read before func_timeout aborts it.
VIDEO_READER_TIMEOUT = 20
24
+
25
def _random_block_bounds(h, w):
    """Pick a random rectangular region inside an h x w frame.

    The block is centered at a uniform random pixel; its width/height are
    drawn uniformly from [side // 4, 3 * side // 4) and clipped to the frame.

    Returns:
        (start_y, end_y, start_x, end_x) slice bounds.
    """
    center_x = torch.randint(0, w, (1,)).item()
    center_y = torch.randint(0, h, (1,)).item()
    block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # block width range
    block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # block height range

    start_x = max(center_x - block_size_x // 2, 0)
    end_x = min(center_x + block_size_x // 2, w)
    start_y = max(center_y - block_size_y // 2, 0)
    end_y = min(center_y + block_size_y // 2, h)
    return start_y, end_y, start_x, end_x


def get_random_mask(shape):
    """Generate a random binary inpainting mask for a video clip.

    One of four mask styles is drawn uniformly:
      0: all frames after the first are masked,
      1: all frames except the first and last are masked,
      2: a random rectangle is masked in every frame,
      3: a random rectangle is masked in a random frame range
         spanning the clip's midpoint.

    Args:
        shape: (frames, channels, height, width) of the clip; channels is
            ignored — the mask always has a single channel.

    Returns:
        torch.uint8 tensor of shape (frames, 1, height, width) with values
        in {0, 1}, where 1 marks masked pixels.
    """
    f, c, h, w = shape

    mask_index = np.random.randint(0, 4)
    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
    if mask_index == 0:
        mask[1:, :, :, :] = 1
    elif mask_index == 1:
        mask_frame_index = 1
        mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
    elif mask_index == 2:
        start_y, end_y, start_x, end_x = _random_block_bounds(h, w)
        mask[:, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 3:
        start_y, end_y, start_x, end_x = _random_block_bounds(h, w)
        # Random temporal extent straddling the clip midpoint.
        mask_frame_before = np.random.randint(0, f // 2)
        mask_frame_after = np.random.randint(f // 2, f)
        mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
    else:
        raise ValueError(f"The mask_index {mask_index} is not defined")
    return mask
63
+
64
+
65
@contextmanager
def VideoReader_contextmanager(*args, **kwargs):
    """Context manager wrapping decord.VideoReader.

    Guarantees the reader is dropped and garbage collection runs on exit,
    even when the body raises — decord readers hold native resources.
    Positional/keyword args are forwarded to VideoReader unchanged.
    """
    vr = VideoReader(*args, **kwargs)
    try:
        yield vr
    finally:
        del vr
        gc.collect()
73
+
74
+
75
def get_video_reader_batch(video_reader, batch_index):
    """Decode the frames at `batch_index` and return them as a numpy array.

    Kept as a module-level function so it can be run under func_timeout.
    """
    frames = video_reader.get_batch(batch_index).asnumpy()
    return frames
78
+
79
+
80
class WebVid10M(Dataset):
    """WebVid-10M dataset: text-video pairs read from a CSV annotation file.

    Each CSV row provides 'videoid', 'name' (the caption) and 'page_dir';
    videos are stored as <video_folder>/<videoid>.mp4. With `is_image=True`
    a single random frame is returned instead of a clip.
    """

    def __init__(
        self,
        csv_path, video_folder,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        enable_bucket=False, enable_inpaint=False, is_image=False,
    ):
        print(f"loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.enable_bucket = enable_bucket
        self.enable_inpaint = enable_inpaint
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        """Decode a clip (or single frame) and return (pixel_values, caption).

        Bucket mode returns raw uint8 numpy frames; otherwise a float tensor
        in [0, 1] with layout (frames, channels, height, width).
        """
        video_dict = self.dataset[idx]
        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']

        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = VideoReader(video_dir)
        video_length = len(video_reader)

        if not self.is_image:
            # Evenly spaced frame indices from a random clip window.
            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            batch_index = [random.randint(0, video_length - 1)]

        if not self.enable_bucket:
            pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
            pixel_values = pixel_values / 255.
            del video_reader
        else:
            pixel_values = video_reader.get_batch(batch_index).asnumpy()

        if self.is_image:
            pixel_values = pixel_values[0]
        return pixel_values, name

    def __len__(self):
        """Return the number of CSV annotation rows."""
        return self.length

    def __getitem__(self, idx):
        """Fetch one sample dict, retrying with a random index on failure."""
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break

            except Exception as e:
                print("Error info:", e)
                idx = random.randint(0, self.length-1)

        if not self.enable_bucket:
            pixel_values = self.pixel_transforms(pixel_values)
        if self.enable_inpaint:
            # Masked pixels are replaced by -1 (out-of-range sentinel).
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
        else:
            sample = dict(pixel_values=pixel_values, text=name)
        return sample
155
+
156
+
157
class VideoDataset(Dataset):
    """Generic text-video dataset driven by a JSON annotation list.

    Each JSON entry provides 'file_path' (absolute, or relative to
    `video_folder`) and 'text'. Clips of `sample_n_frames` frames are drawn
    at `sample_stride` spacing from a random window of each video.
    """

    def __init__(
        self,
        json_path, video_folder=None,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        enable_bucket=False, enable_inpaint=False
    ):
        print(f"loading annotations from {json_path} ...")
        self.dataset = json.load(open(json_path, 'r'))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.enable_bucket = enable_bucket
        self.enable_inpaint = enable_inpaint

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose(
            [
                transforms.Resize(sample_size[0]),
                transforms.CenterCrop(sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

    def get_batch(self, idx):
        """Decode a clip and return (pixel_values, caption).

        Raises:
            ValueError: decode timeout or decode failure.
        """
        video_dict = self.dataset[idx]
        video_id, name = video_dict['file_path'], video_dict['text']

        if self.video_folder is None:
            video_dir = video_id
        else:
            video_dir = os.path.join(self.video_folder, video_id)

        with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
            video_length = len(video_reader)

            # Evenly spaced frame indices from a random clip window.
            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)

            try:
                # func_timeout guards against decord hanging on bad files.
                sample_args = (video_reader, batch_index)
                pixel_values = func_timeout(
                    VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                )
            except FunctionTimedOut:
                raise ValueError(f"Read {idx} timeout.")
            except Exception as e:
                raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            if not self.enable_bucket:
                pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                pixel_values = pixel_values / 255.
                del video_reader
            else:
                pixel_values = pixel_values

            return pixel_values, name

    def __len__(self):
        """Return the number of JSON annotation entries."""
        return self.length

    def __getitem__(self, idx):
        """Fetch one sample dict, retrying with a random index on failure."""
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break

            except Exception as e:
                print("Error info:", e)
                idx = random.randint(0, self.length-1)

        if not self.enable_bucket:
            pixel_values = self.pixel_transforms(pixel_values)
        if self.enable_inpaint:
            # Masked pixels are replaced by -1 (out-of-range sentinel).
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
        else:
            sample = dict(pixel_values=pixel_values, text=name)
        return sample
241
+
242
+
243
+ if __name__ == "__main__":
244
+ if 1:
245
+ dataset = VideoDataset(
246
+ json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
247
+ sample_size=256,
248
+ sample_stride=4, sample_n_frames=16,
249
+ )
250
+
251
+ if 0:
252
+ dataset = WebVid10M(
253
+ csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
254
+ video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
255
+ sample_size=256,
256
+ sample_stride=4, sample_n_frames=16,
257
+ is_image=False,
258
+ )
259
+
260
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
261
+ for idx, batch in enumerate(dataloader):
262
+ print(batch["pixel_values"].shape, len(batch["text"]))
videox_fun/dist/__init__.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.distributed as dist

# xfuser (sequence-parallel attention) is an optional dependency: when it is
# missing, every imported symbol is set to None and set_multi_gpus_devices
# raises only if multi-GPU parallelism is actually requested.
try:
    import xfuser
    from xfuser.core.distributed import (get_sequence_parallel_rank,
                                         get_sequence_parallel_world_size,
                                         get_sp_group, get_world_group,
                                         init_distributed_environment,
                                         initialize_model_parallel)
    from xfuser.core.long_ctx_attention import xFuserLongContextAttention
except Exception as ex:
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    xFuserLongContextAttention = None
    get_sp_group = None
    get_world_group = None
    init_distributed_environment = None
    initialize_model_parallel = None
20
+
21
def set_multi_gpus_devices(ulysses_degree, ring_degree):
    """Select the device for (optionally sequence-parallel) inference.

    When either parallel degree exceeds 1, initializes torch.distributed
    (NCCL) plus the xfuser sequence-parallel groups and returns this rank's
    local CUDA device. Otherwise returns the plain "cuda" device string.

    Args:
        ulysses_degree: Ulysses sequence-parallel degree.
        ring_degree: ring-attention sequence-parallel degree.

    Returns:
        torch.device for the local rank (multi-GPU path) or the string
        "cuda" (single-GPU path).

    Raises:
        RuntimeError: parallelism requested but xfuser is not installed.
        AssertionError: world size != ring_degree * ulysses_degree.
    """
    if ulysses_degree > 1 or ring_degree > 1:
        if get_sp_group is None:
            raise RuntimeError("xfuser is not installed.")
        dist.init_process_group("nccl")
        # f-strings for consistency with the rest of this module.
        print(f'parallel inference enabled: ulysses_degree={ulysses_degree} '
              f'ring_degree={ring_degree} rank={dist.get_rank()} world_size={dist.get_world_size()}')
        assert dist.get_world_size() == ring_degree * ulysses_degree, \
            "number of GPUs(%d) should be equal to ring_degree * ulysses_degree." % dist.get_world_size()
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(),
                                  ring_degree=ring_degree,
                                  ulysses_degree=ulysses_degree)
        # Bind to the *local* rank's GPU: ranks may span multiple nodes.
        device = torch.device(f"cuda:{get_world_group().local_rank}")
        print(f'rank={get_world_group().rank} device={device}')
    else:
        device = "cuda"
    return device
videox_fun/dist/cogvideox_xfuser.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers.models.attention import Attention
6
+ from diffusers.models.embeddings import apply_rotary_emb
7
+
8
+ try:
9
+ import xfuser
10
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
11
+ get_sequence_parallel_world_size,
12
+ get_sp_group,
13
+ init_distributed_environment,
14
+ initialize_model_parallel)
15
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
16
+ except Exception as ex:
17
+ get_sequence_parallel_world_size = None
18
+ get_sequence_parallel_rank = None
19
+ xFuserLongContextAttention = None
20
+ get_sp_group = None
21
+ init_distributed_environment = None
22
+ initialize_model_parallel = None
23
+
24
+ class CogVideoXMultiGPUsAttnProcessor2_0:
25
+ r"""
26
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
27
+ query and key vectors, but does not include spatial normalization.
28
+ """
29
+
30
+ def __init__(self):
31
+ if xFuserLongContextAttention is not None:
32
+ try:
33
+ self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
34
+ except Exception:
35
+ self.hybrid_seq_parallel_attn = None
36
+ else:
37
+ self.hybrid_seq_parallel_attn = None
38
+ if not hasattr(F, "scaled_dot_product_attention"):
39
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
40
+
41
+ def __call__(
42
+ self,
43
+ attn: Attention,
44
+ hidden_states: torch.Tensor,
45
+ encoder_hidden_states: torch.Tensor,
46
+ attention_mask: Optional[torch.Tensor] = None,
47
+ image_rotary_emb: Optional[torch.Tensor] = None,
48
+ ) -> torch.Tensor:
49
+ text_seq_length = encoder_hidden_states.size(1)
50
+
51
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
52
+
53
+ batch_size, sequence_length, _ = (
54
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
55
+ )
56
+
57
+ if attention_mask is not None:
58
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
59
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
60
+
61
+ query = attn.to_q(hidden_states)
62
+ key = attn.to_k(hidden_states)
63
+ value = attn.to_v(hidden_states)
64
+
65
+ inner_dim = key.shape[-1]
66
+ head_dim = inner_dim // attn.heads
67
+
68
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
69
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
70
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
71
+
72
+ if attn.norm_q is not None:
73
+ query = attn.norm_q(query)
74
+ if attn.norm_k is not None:
75
+ key = attn.norm_k(key)
76
+
77
+ # Apply RoPE if needed
78
+ if image_rotary_emb is not None:
79
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
80
+ if not attn.is_cross_attention:
81
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
82
+
83
+ if self.hybrid_seq_parallel_attn is None:
84
+ hidden_states = F.scaled_dot_product_attention(
85
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
86
+ )
87
+ hidden_states = hidden_states
88
+ else:
89
+ img_q = query[:, :, text_seq_length:].transpose(1, 2)
90
+ txt_q = query[:, :, :text_seq_length].transpose(1, 2)
91
+ img_k = key[:, :, text_seq_length:].transpose(1, 2)
92
+ txt_k = key[:, :, :text_seq_length].transpose(1, 2)
93
+ img_v = value[:, :, text_seq_length:].transpose(1, 2)
94
+ txt_v = value[:, :, :text_seq_length].transpose(1, 2)
95
+
96
+ hidden_states = self.hybrid_seq_parallel_attn(
97
+ None,
98
+ img_q, img_k, img_v, dropout_p=0.0, causal=False,
99
+ joint_tensor_query=txt_q,
100
+ joint_tensor_key=txt_k,
101
+ joint_tensor_value=txt_v,
102
+ joint_strategy='front',
103
+ ).transpose(1, 2)
104
+
105
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
106
+
107
+ # linear proj
108
+ hidden_states = attn.to_out[0](hidden_states)
109
+ # dropout
110
+ hidden_states = attn.to_out[1](hidden_states)
111
+
112
+ encoder_hidden_states, hidden_states = hidden_states.split(
113
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
114
+ )
115
+ return hidden_states, encoder_hidden_states
116
+
videox_fun/dist/wan_xfuser.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.cuda.amp as amp
3
+
4
+ try:
5
+ import xfuser
6
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
7
+ get_sequence_parallel_world_size,
8
+ get_sp_group,
9
+ init_distributed_environment,
10
+ initialize_model_parallel)
11
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
12
+ except Exception as ex:
13
+ get_sequence_parallel_world_size = None
14
+ get_sequence_parallel_rank = None
15
+ xFuserLongContextAttention = None
16
+ get_sp_group = None
17
+ init_distributed_environment = None
18
+ initialize_model_parallel = None
19
+
20
def pad_freqs(original_tensor, target_len):
    """
    Extend a frequency tensor along its first axis to ``target_len`` rows.

    The appended rows are filled with ones (the multiplicative identity for the
    rotary-embedding product), matching the input's dtype and device.

    Args:
        original_tensor: Tensor of shape ``[seq_len, s1, s2]``.
        target_len: Desired length of the first dimension; must be >= ``seq_len``.

    Returns:
        Tensor of shape ``[target_len, s1, s2]`` whose leading rows equal the input.
    """
    current_len, dim1, dim2 = original_tensor.shape
    # new_ones inherits dtype/device from the source tensor.
    filler = original_tensor.new_ones((target_len - current_len, dim1, dim2))
    return torch.cat((original_tensor, filler), dim=0)
31
+
32
+ @amp.autocast(enabled=False)
33
+ def rope_apply(x, grid_sizes, freqs):
34
+ """
35
+ x: [B, L, N, C].
36
+ grid_sizes: [B, 3].
37
+ freqs: [M, C // 2].
38
+ """
39
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
40
+ # split freqs
41
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
42
+
43
+ # loop over samples
44
+ output = []
45
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
46
+ seq_len = f * h * w
47
+
48
+ # precompute multipliers
49
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape(
50
+ s, n, -1, 2))
51
+ freqs_i = torch.cat([
52
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
53
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
54
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
55
+ ],
56
+ dim=-1).reshape(seq_len, 1, -1)
57
+
58
+ # apply rotary embedding
59
+ sp_size = get_sequence_parallel_world_size()
60
+ sp_rank = get_sequence_parallel_rank()
61
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
62
+ s_per_rank = s
63
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
64
+ s_per_rank), :, :]
65
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
66
+ x_i = torch.cat([x_i, x[i, s:]])
67
+
68
+ # append to collection
69
+ output.append(x_i)
70
+ return torch.stack(output)
71
+
72
def usp_attn_forward(self,
                     x,
                     seq_lens,
                     grid_sizes,
                     freqs,
                     dtype=torch.bfloat16):
    """
    Sequence-parallel replacement for the Wan self-attention forward.

    Monkey-patched onto the attention module (hence the explicit ``self``);
    assumes ``self`` provides q/k/v/o projections, norm_q/norm_k, num_heads,
    head_dim and window_size — TODO confirm against the patched module.

    x: [B, L, N*D] per-rank hidden states.
    seq_lens: currently unused (see the TODO below about unpadded attention).
    grid_sizes: [B, 3] latent (frames, height, width) per sample.
    freqs: precomputed rotary frequency table.
    dtype: half dtype to cast q/k/v to before the fused attention kernel.
    """
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    half_dtypes = (torch.float16, torch.bfloat16)

    def half(x):
        # Cast to the requested half dtype only if not already half precision.
        return x if x.dtype in half_dtypes else x.to(dtype)

    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)
        return q, k, v

    q, k, v = qkv_fn(x)
    # Rotary embedding is applied per-rank with sequence-parallel-aware slicing.
    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)

    # TODO: We should use unpaded q,k,v for attention.
    # k_lens = seq_lens // get_sequence_parallel_world_size()
    # if k_lens is not None:
    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)

    # Hybrid Ulysses/ring long-context attention across the sequence-parallel group.
    x = xFuserLongContextAttention()(
        None,
        query=half(q),
        key=half(k),
        value=half(v),
        window_size=self.window_size)

    # TODO: padding after attention.
    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)

    # output: merge heads and project back to the model dimension.
    x = x.flatten(2)
    x = self.o(x)
    return x
videox_fun/models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
2
+
3
+ from .cogvideox_transformer3d import CogVideoXTransformer3DModel
4
+ from .cogvideox_vae import AutoencoderKLCogVideoX
videox_fun/models/cache_utils.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
def get_teacache_coefficients(model_name):
    """
    Look up the TeaCache rescaling-polynomial coefficients for a known model.

    Matching is case-insensitive and by substring, so full checkpoint paths work.

    Args:
        model_name: Model name or path to match against known checkpoints.

    Returns:
        A list of five polynomial coefficients (highest degree first), or ``None``
        (with a message printed) when the model is not supported.
    """
    lowered = model_name.lower()
    coefficient_table = (
        (("wan2.1-t2v-1.3b", "wan2.1-fun-1.3b"),
         [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]),
        (("wan2.1-t2v-14b",),
         [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]),
        (("wan2.1-i2v-14b-480p",),
         [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]),
        (("wan2.1-i2v-14b-720p", "wan2.1-fun-14b"),
         [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]),
    )
    for keywords, coefficients in coefficient_table:
        if any(keyword in lowered for keyword in keywords):
            return coefficients
    print(f"The model {model_name} is not supported by TeaCache.")
    return None
17
+
18
+
19
class TeaCache():
    """
    Timestep Embedding Aware Cache, a training-free caching approach that estimates and leverages
    the fluctuating differences among model outputs across timesteps, thereby accelerating the inference.
    Please refer to:
    1. https://github.com/ali-vilab/TeaCache.
    2. Liu, Feng, et al. "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model." arXiv preprint arXiv:2411.19108 (2024).
    """
    def __init__(
        self,
        coefficients: list[float],
        num_steps: int,
        rel_l1_thresh: float = 0.0,
        num_skip_start_steps: int = 0,
        offload: bool = True,
    ):
        """
        Args:
            coefficients: Polynomial coefficients (highest degree first) used to rescale
                the raw relative-L1 distance; passed to ``np.poly1d``.
            num_steps: Total number of denoising steps in one sampling run; must be >= 1.
            rel_l1_thresh: Accumulated rescaled distance above which the model must be
                re-evaluated instead of reusing a cached residual; must be >= 0.
            num_skip_start_steps: Number of initial steps during which caching is skipped;
                must lie in ``[0, num_steps]``.
            offload: Whether cached tensors should be moved off the GPU by the consumer.

        Raises:
            ValueError: If any argument is outside its valid range.
        """
        if num_steps < 1:
            raise ValueError(f"`num_steps` must be greater than 0 but is {num_steps}.")
        if rel_l1_thresh < 0:
            raise ValueError(f"`rel_l1_thresh` must be greater than or equal to 0 but is {rel_l1_thresh}.")
        if num_skip_start_steps < 0 or num_skip_start_steps > num_steps:
            # Fixed typo in the original message: "great than" -> "greater than".
            raise ValueError(
                "`num_skip_start_steps` must be greater than or equal to 0 and "
                f"less than or equal to `num_steps={num_steps}` but is {num_skip_start_steps}."
            )
        self.coefficients = coefficients
        self.num_steps = num_steps
        self.rel_l1_thresh = rel_l1_thresh
        self.num_skip_start_steps = num_skip_start_steps
        self.offload = offload
        self.rescale_func = np.poly1d(self.coefficients)

        # Initialize all mutable per-run state in one place; `reset()` is the single
        # source of truth so the constructor and mid-run resets cannot drift apart.
        self.reset()

    @staticmethod
    def compute_rel_l1_distance(prev: torch.Tensor, cur: torch.Tensor) -> float:
        """Return ``mean(|cur - prev|) / mean(|prev|)`` as a Python float."""
        rel_l1_distance = (torch.abs(cur - prev).mean()) / torch.abs(prev).mean()

        return rel_l1_distance.cpu().item()

    def reset(self):
        """Clear all cached state so this instance can be reused for a new sampling run."""
        self.cnt = 0                          # current denoising-step counter
        self.should_calc = True               # whether the next step must run the full model
        self.accumulated_rel_l1_distance = 0  # running sum of rescaled distances
        self.previous_modulated_input = None
        # Some pipelines concatenate the unconditional and text guide in forward.
        self.previous_residual = None
        # Some pipelines perform forward propagation separately on the unconditional and text guide.
        self.previous_residual_cond = None
        self.previous_residual_uncond = None
videox_fun/models/cogvideox_transformer3d.py ADDED
@@ -0,0 +1,845 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import glob
17
+ import json
18
+ import os
19
+ from typing import Any, Dict, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.attention import Attention, FeedForward
25
+ from diffusers.models.attention_processor import (
26
+ AttentionProcessor, CogVideoXAttnProcessor2_0,
27
+ FusedCogVideoXAttnProcessor2_0)
28
+ from diffusers.models.embeddings import (CogVideoXPatchEmbed,
29
+ TimestepEmbedding, Timesteps,
30
+ get_3d_sincos_pos_embed)
31
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
34
+ from diffusers.utils import is_torch_version, logging
35
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
36
+ from torch import nn
37
+
38
+ from ..dist import (get_sequence_parallel_rank,
39
+ get_sequence_parallel_world_size,
40
+ get_sp_group,
41
+ xFuserLongContextAttention)
42
+ from ..dist.cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
43
+
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+
48
class CogVideoXPatchEmbed(nn.Module):
    """
    Patch embedding for CogVideoX that joins projected text tokens with patchified
    video latents into one sequence.

    This local override shadows the diffusers class of the same name: when (learned)
    positional embeddings are used, the stored table is resized to the runtime spatial
    resolution with trilinear interpolation instead of being sliced, so inference at
    resolutions other than the training resolution still gets positional information.
    """
    def __init__(
        self,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_positional_embeddings: bool = True,
        use_learned_positional_embeddings: bool = True,
    ) -> None:
        super().__init__()

        # Grid dimensions after spatial patchify and temporal VAE compression.
        post_patch_height = sample_height // patch_size
        post_patch_width = sample_width // patch_size
        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
        self.post_patch_height = post_patch_height
        self.post_patch_width = post_patch_width
        self.post_time_compression_frames = post_time_compression_frames
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.embed_dim = embed_dim
        self.sample_height = sample_height
        self.sample_width = sample_width
        self.sample_frames = sample_frames
        self.temporal_compression_ratio = temporal_compression_ratio
        self.max_text_seq_length = max_text_seq_length
        self.spatial_interpolation_scale = spatial_interpolation_scale
        self.temporal_interpolation_scale = temporal_interpolation_scale
        self.use_positional_embeddings = use_positional_embeddings
        self.use_learned_positional_embeddings = use_learned_positional_embeddings

        if patch_size_t is None:
            # CogVideoX 1.0 checkpoints: 2D conv patchify, one frame at a time.
            self.proj = nn.Conv2d(
                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
            )
        else:
            # CogVideoX 1.5 checkpoints: linear projection over a (t, h, w) patch volume.
            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)

        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

        if use_positional_embeddings or use_learned_positional_embeddings:
            # Learned tables must survive checkpointing; fixed sincos tables can be rebuilt.
            persistent = use_learned_positional_embeddings
            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
        """Build the joint (text + video) positional table; text slots are left as zeros."""
        post_patch_height = sample_height // self.patch_size
        post_patch_width = sample_width // self.patch_size
        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
        num_patches = post_patch_height * post_patch_width * post_time_compression_frames

        pos_embedding = get_3d_sincos_pos_embed(
            self.embed_dim,
            (post_patch_width, post_patch_height),
            post_time_compression_frames,
            self.spatial_interpolation_scale,
            self.temporal_interpolation_scale,
        )
        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
        joint_pos_embedding = torch.zeros(
            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
        )
        # Only the video region of the table receives the sincos embedding.
        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)

        return joint_pos_embedding

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).

        Returns:
            Joint `[text | video]` token sequence of shape
            (batch_size, seq_length + num_video_tokens, embed_dim), with positional
            embeddings added when enabled.
        """
        text_embeds = self.text_proj(text_embeds)

        text_batch_size, text_seq_length, text_channels = text_embeds.shape
        batch_size, num_frames, channels, height, width = image_embeds.shape

        if self.patch_size_t is None:
            # CogVideoX 1.0: fold frames into the batch and apply the 2D conv patchify.
            image_embeds = image_embeds.reshape(-1, channels, height, width)
            image_embeds = self.proj(image_embeds)
            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
        else:
            # CogVideoX 1.5: gather (p_t, p, p) patch volumes and project linearly.
            p = self.patch_size
            p_t = self.patch_size_t

            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
            # b, f, h, w, c => b, f // 2, 2, h // 2, 2, w // 2, 2, c
            image_embeds = image_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
            )
            # b, f // 2, 2, h // 2, 2, w // 2, 2, c => b, f // 2, h // 2, w // 2, c, 2, 2, 2
            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
            image_embeds = self.proj(image_embeds)

        embeds = torch.cat(
            [text_embeds, image_embeds], dim=1
        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]

        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            seq_length = height * width * num_frames // (self.patch_size**2)
            # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
            pos_embeds = self.pos_embedding
            emb_size = embeds.size()[-1]
            # Reshape the stored video positional table back to its (t, h, w) grid,
            # then trilinearly interpolate it to the runtime spatial resolution so
            # inputs larger/smaller than the training resolution still line up.
            pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
            pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
            pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.patch_size, width // self.patch_size], mode='trilinear', align_corners=False)
            pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
            pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
            # Trim in case fewer frames than the table's temporal extent are present.
            pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
            embeds = embeds + pos_embeds

        return embeds
176
+
177
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool` defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self Attention (adaLN-zero norm conditioned on the timestep embedding)
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run one joint (text + video) transformer block.

        Args:
            hidden_states: Video-token hidden states, shape (batch, num_video_tokens, dim).
            encoder_hidden_states: Text-token hidden states, shape (batch, text_seq_len, dim).
            temb: Timestep embedding driving the adaLN-zero gates.
            image_rotary_emb: Optional rotary embedding forwarded to the attention processor.

        Returns:
            The updated `(hidden_states, encoder_hidden_states)` pair.
        """
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # Gated residual connections (gates come from the timestep modulation).
        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward over the joint [text | video] sequence
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        # Split the FF output back: text tokens occupy the front of the sequence.
        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states
295
+
296
+
297
+ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
298
+ """
299
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
300
+
301
+ Parameters:
302
+ num_attention_heads (`int`, defaults to `30`):
303
+ The number of heads to use for multi-head attention.
304
+ attention_head_dim (`int`, defaults to `64`):
305
+ The number of channels in each head.
306
+ in_channels (`int`, defaults to `16`):
307
+ The number of channels in the input.
308
+ out_channels (`int`, *optional*, defaults to `16`):
309
+ The number of channels in the output.
310
+ flip_sin_to_cos (`bool`, defaults to `True`):
311
+ Whether to flip the sin to cos in the time embedding.
312
+ time_embed_dim (`int`, defaults to `512`):
313
+ Output dimension of timestep embeddings.
314
+ text_embed_dim (`int`, defaults to `4096`):
315
+ Input dimension of text embeddings from the text encoder.
316
+ num_layers (`int`, defaults to `30`):
317
+ The number of layers of Transformer blocks to use.
318
+ dropout (`float`, defaults to `0.0`):
319
+ The dropout probability to use.
320
+ attention_bias (`bool`, defaults to `True`):
321
+ Whether or not to use bias in the attention projection layers.
322
+ sample_width (`int`, defaults to `90`):
323
+ The width of the input latents.
324
+ sample_height (`int`, defaults to `60`):
325
+ The height of the input latents.
326
+ sample_frames (`int`, defaults to `49`):
327
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
328
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
329
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
330
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
331
+ patch_size (`int`, defaults to `2`):
332
+ The size of the patches to use in the patch embedding layer.
333
+ temporal_compression_ratio (`int`, defaults to `4`):
334
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
335
+ max_text_seq_length (`int`, defaults to `226`):
336
+ The maximum sequence length of the input text embeddings.
337
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
338
+ Activation function to use in feed-forward.
339
+ timestep_activation_fn (`str`, defaults to `"silu"`):
340
+ Activation function to use when generating the timestep embeddings.
341
+ norm_elementwise_affine (`bool`, defaults to `True`):
342
+ Whether or not to use elementwise affine in normalization layers.
343
+ norm_eps (`float`, defaults to `1e-5`):
344
+ The epsilon value to use in normalization layers.
345
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
346
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
347
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
348
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
349
+ """
350
+
351
+ _supports_gradient_checkpointing = True
352
+
353
+ @register_to_config
354
+ def __init__(
355
+ self,
356
+ num_attention_heads: int = 30,
357
+ attention_head_dim: int = 64,
358
+ in_channels: int = 16,
359
+ out_channels: Optional[int] = 16,
360
+ flip_sin_to_cos: bool = True,
361
+ freq_shift: int = 0,
362
+ time_embed_dim: int = 512,
363
+ text_embed_dim: int = 4096,
364
+ num_layers: int = 30,
365
+ dropout: float = 0.0,
366
+ attention_bias: bool = True,
367
+ sample_width: int = 90,
368
+ sample_height: int = 60,
369
+ sample_frames: int = 49,
370
+ patch_size: int = 2,
371
+ patch_size_t: Optional[int] = None,
372
+ temporal_compression_ratio: int = 4,
373
+ max_text_seq_length: int = 226,
374
+ activation_fn: str = "gelu-approximate",
375
+ timestep_activation_fn: str = "silu",
376
+ norm_elementwise_affine: bool = True,
377
+ norm_eps: float = 1e-5,
378
+ spatial_interpolation_scale: float = 1.875,
379
+ temporal_interpolation_scale: float = 1.0,
380
+ use_rotary_positional_embeddings: bool = False,
381
+ use_learned_positional_embeddings: bool = False,
382
+ patch_bias: bool = True,
383
+ add_noise_in_inpaint_model: bool = False,
384
+ ):
385
+ super().__init__()
386
+ inner_dim = num_attention_heads * attention_head_dim
387
+ self.patch_size_t = patch_size_t
388
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
389
+ raise ValueError(
390
+ "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
391
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
392
+ "issue at https://github.com/huggingface/diffusers/issues."
393
+ )
394
+
395
+ # 1. Patch embedding
396
+ self.patch_embed = CogVideoXPatchEmbed(
397
+ patch_size=patch_size,
398
+ patch_size_t=patch_size_t,
399
+ in_channels=in_channels,
400
+ embed_dim=inner_dim,
401
+ text_embed_dim=text_embed_dim,
402
+ bias=patch_bias,
403
+ sample_width=sample_width,
404
+ sample_height=sample_height,
405
+ sample_frames=sample_frames,
406
+ temporal_compression_ratio=temporal_compression_ratio,
407
+ max_text_seq_length=max_text_seq_length,
408
+ spatial_interpolation_scale=spatial_interpolation_scale,
409
+ temporal_interpolation_scale=temporal_interpolation_scale,
410
+ use_positional_embeddings=not use_rotary_positional_embeddings,
411
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
412
+ )
413
+ self.embedding_dropout = nn.Dropout(dropout)
414
+
415
+ # 2. Time embeddings
416
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
417
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
418
+
419
+ # 3. Define spatio-temporal transformers blocks
420
+ self.transformer_blocks = nn.ModuleList(
421
+ [
422
+ CogVideoXBlock(
423
+ dim=inner_dim,
424
+ num_attention_heads=num_attention_heads,
425
+ attention_head_dim=attention_head_dim,
426
+ time_embed_dim=time_embed_dim,
427
+ dropout=dropout,
428
+ activation_fn=activation_fn,
429
+ attention_bias=attention_bias,
430
+ norm_elementwise_affine=norm_elementwise_affine,
431
+ norm_eps=norm_eps,
432
+ )
433
+ for _ in range(num_layers)
434
+ ]
435
+ )
436
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
437
+
438
+ # 4. Output blocks
439
+ self.norm_out = AdaLayerNorm(
440
+ embedding_dim=time_embed_dim,
441
+ output_dim=2 * inner_dim,
442
+ norm_elementwise_affine=norm_elementwise_affine,
443
+ norm_eps=norm_eps,
444
+ chunk_dim=1,
445
+ )
446
+
447
+ if patch_size_t is None:
448
+ # For CogVideox 1.0
449
+ output_dim = patch_size * patch_size * out_channels
450
+ else:
451
+ # For CogVideoX 1.5
452
+ output_dim = patch_size * patch_size * patch_size_t * out_channels
453
+
454
+ self.proj_out = nn.Linear(inner_dim, output_dim)
455
+
456
+ self.gradient_checkpointing = False
457
+ self.sp_world_size = 1
458
+ self.sp_world_rank = 0
459
+
460
+ def _set_gradient_checkpointing(self, module, value=False):
461
+ self.gradient_checkpointing = value
462
+
463
    def enable_multi_gpus_inference(self,):
        """Enable sequence-parallel multi-GPU inference.

        Records this process's rank and world size from the sequence-parallel
        group and installs the multi-GPU attention processor. The distributed
        sequence-parallel environment must already be initialized.
        """
        self.sp_world_size = get_sequence_parallel_world_size()
        self.sp_world_rank = get_sequence_parallel_rank()
        self.set_attn_processor(CogVideoXMultiGPUsAttnProcessor2_0())
467
+
468
    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # Any submodule exposing `get_processor` is an attention layer; record
            # its processor under the same dotted path used in the state dict.
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors
492
+
493
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        # A per-layer dict must provide exactly one processor per attention layer.
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    # A single processor instance is shared by all attention layers.
                    module.set_processor(processor)
                else:
                    # Dict entries are consumed (popped) as their layers are visited.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
527
+
528
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            # Processors with added KV projections (IP-Adapter style) carry extra
            # projection matrices that the fused processor cannot represent.
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # Keep the unfused processors so `unfuse_qkv_projections` can restore them.
        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
553
+
554
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        # NOTE(review): `original_attn_processors` is only assigned inside
        # `fuse_qkv_projections()`; calling this method first raises
        # AttributeError (same behavior as the upstream diffusers copy).
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
567
+
568
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        inpaint_latents: Optional[torch.Tensor] = None,
        control_latents: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        return_dict: bool = True,
    ):
        """Denoising forward pass of the CogVideoX transformer.

        Args:
            hidden_states: Noisy video latents of shape (batch, frames, channels, height, width).
            encoder_hidden_states: Text-encoder embeddings, shape (batch, text_seq_len, text_embed_dim).
            timestep: Diffusion timestep(s) fed to the time projection/embedding.
            timestep_cond: Optional extra conditioning for the timestep embedding.
            inpaint_latents: Optional inpainting condition, concatenated on the channel dim.
            control_latents: Optional control condition, concatenated on the channel dim.
            image_rotary_emb: Optional (cos, sin) rotary embeddings for the vision tokens.
            return_dict: If True, wrap the prediction in a `Transformer2DModelOutput`.

        Returns:
            Predicted latents with the same frame/spatial layout as `hidden_states`,
            as a 1-tuple or `Transformer2DModelOutput`.
        """
        batch_size, num_frames, channels, height, width = hidden_states.shape
        # CogVideoX 1.5 (patch_size_t set) cannot patchify a single frame, so a
        # zero frame is padded in here and stripped again after unpatchify.
        if num_frames == 1 and self.patch_size_t is not None:
            hidden_states = torch.cat([hidden_states, torch.zeros_like(hidden_states)], dim=1)
            if inpaint_latents is not None:
                inpaint_latents = torch.concat([inpaint_latents, torch.zeros_like(inpaint_latents)], dim=1)
            if control_latents is not None:
                control_latents = torch.concat([control_latents, torch.zeros_like(control_latents)], dim=1)
            local_num_frames = num_frames + 1
        else:
            local_num_frames = num_frames

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding — conditioning latents are stacked on the channel axis.
        if inpaint_latents is not None:
            hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
        if control_latents is not None:
            hidden_states = torch.concat([hidden_states, control_latents], 2)
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        # patch_embed prepends the text tokens; split them back out here.
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # Context Parallel: each rank keeps only its slice of the vision tokens
        # (and the matching slice of the rotary embeddings).
        if self.sp_world_size > 1:
            hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
            if image_rotary_emb is not None:
                image_rotary_emb = (
                    torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
                    torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank]
                )

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B: final norm runs over text + vision tokens jointly,
            # then the text tokens are dropped.
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # Gather per-rank slices back into the full token sequence.
        if self.sp_world_size > 1:
            hidden_states = get_sp_group().all_gather(hidden_states, dim=1)

        # 5. Unpatchify
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        if p_t is None:
            output = hidden_states.reshape(batch_size, local_num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                batch_size, (local_num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        # Drop the zero frame padded in for the single-frame case above.
        if num_frames == 1:
            output = output[:, :num_frames, :]

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
683
+
684
    @classmethod
    def from_pretrained(
        cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
        low_cpu_mem_usage=False, torch_dtype=torch.bfloat16, use_vae_mask=False, stack_mask=False,
    ):
        """Load the transformer from a local diffusers-style checkpoint directory.

        Args:
            pretrained_model_path: Directory containing `config.json` and weights
                (a `.bin` file, a single `.safetensors`, or sharded `.safetensors`).
            subfolder: Optional subdirectory appended to `pretrained_model_path`.
            transformer_additional_kwargs: Extra config overrides; the special key
                `dict_mapping` copies config values between entries before init.
            low_cpu_mem_usage: If True, try a meta-device init + direct weight load
                (requires `accelerate`); falls back to the normal path on failure.
            torch_dtype: Dtype the returned model is cast to.
            use_vae_mask: Widen `in_channels` to 48 for a VAE-encoded mask input.
            stack_mask: Widen `in_channels` to 36 for a stacked binary mask input.

        Returns:
            The instantiated model with checkpoint weights loaded; tensors whose
            shapes cannot be adapted are skipped (non-strict load).
        """
        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
        print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)

        # Mask conditioning widens the input channel count before instantiation.
        if use_vae_mask:
            print('[DEBUG] use vae to encode mask')
            config['in_channels'] = 48
        elif stack_mask:
            print('[DEBUG] use stacking mask')
            config['in_channels'] = 36

        from diffusers.utils import WEIGHTS_NAME
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        model_file_safetensors = model_file.replace(".bin", ".safetensors")

        if "dict_mapping" in transformer_additional_kwargs.keys():
            for key in transformer_additional_kwargs["dict_mapping"]:
                transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]

        if low_cpu_mem_usage:
            try:
                import re

                from diffusers.models.modeling_utils import \
                    load_model_dict_into_meta
                from diffusers.utils import is_accelerate_available
                if is_accelerate_available():
                    import accelerate

                # Instantiate model with empty weights
                with accelerate.init_empty_weights():
                    model = cls.from_config(config, **transformer_additional_kwargs)

                param_device = "cpu"
                if os.path.exists(model_file):
                    state_dict = torch.load(model_file, map_location="cpu")
                elif os.path.exists(model_file_safetensors):
                    from safetensors.torch import load_file, safe_open
                    state_dict = load_file(model_file_safetensors)
                else:
                    # Sharded checkpoint: merge every *.safetensors file found.
                    from safetensors.torch import load_file, safe_open
                    model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
                    state_dict = {}
                    for _model_file_safetensors in model_files_safetensors:
                        _state_dict = load_file(_model_file_safetensors)
                        for key in _state_dict:
                            state_dict[key] = _state_dict[key]
                model._convert_deprecated_attention_blocks(state_dict)
                # move the params from meta device to cpu
                missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                if len(missing_keys) > 0:
                    raise ValueError(
                        f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
                        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                        " those weights or else make sure your checkpoint file is correct."
                    )

                unexpected_keys = load_model_dict_into_meta(
                    model,
                    state_dict,
                    device=param_device,
                    dtype=torch_dtype,
                    model_name_or_path=pretrained_model_path,
                )

                if cls._keys_to_ignore_on_load_unexpected is not None:
                    for pat in cls._keys_to_ignore_on_load_unexpected:
                        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                if len(unexpected_keys) > 0:
                    print(
                        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
                    )
                return model
            except Exception as e:
                # NOTE(review): broad catch is deliberate — any failure of the
                # meta-device fast path falls through to the standard path below.
                print(
                    f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
                )

        model = cls.from_config(config, **transformer_additional_kwargs)
        if os.path.exists(model_file):
            state_dict = torch.load(model_file, map_location="cpu")
        elif os.path.exists(model_file_safetensors):
            from safetensors.torch import load_file, safe_open
            state_dict = load_file(model_file_safetensors)
        else:
            from safetensors.torch import load_file, safe_open
            model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
            state_dict = {}
            for _model_file_safetensors in model_files_safetensors:
                _state_dict = load_file(_model_file_safetensors)
                for key in _state_dict:
                    state_dict[key] = _state_dict[key]

        # Adapt the patch-embedding weight when the checkpoint's input width differs
        # from the (possibly mask-widened) model, instead of dropping it entirely.
        # NOTE(review): assignments into `model.state_dict()[...]` mutate the live
        # parameters in place (state_dict returns tensor references) — intentional.
        if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
            new_shape = model.state_dict()['patch_embed.proj.weight'].size()
            if len(new_shape) == 5:
                # 3D (temporal) patch embed: replicate along the new time axis and
                # zero all but the last temporal slot.
                state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
                state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
            elif len(new_shape) == 2:
                if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
                    if use_vae_mask:
                        print('[DEBUG] patch_embed.proj.weight size does not match due to vae-encoded mask')
                        latent_ch = 16
                        feat_scale = 8
                        feat_dim = int(latent_ch * feat_scale)
                        old_total_dim = state_dict['patch_embed.proj.weight'].size(1)  # NOTE(review): unused
                        new_total_dim = model.state_dict()['patch_embed.proj.weight'].size(1)
                        model.state_dict()['patch_embed.proj.weight'][:, :feat_dim] = state_dict['patch_embed.proj.weight'][:, :feat_dim]
                        model.state_dict()['patch_embed.proj.weight'][:, -feat_dim:] = state_dict['patch_embed.proj.weight'][:, -feat_dim:]
                        # Tile the middle (mask) columns from the checkpoint's middle band.
                        for i in range(feat_dim, new_total_dim - feat_dim, feat_scale):
                            model.state_dict()['patch_embed.proj.weight'][:, i:i+feat_scale] = state_dict['patch_embed.proj.weight'][:, feat_dim:-feat_dim]
                        state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
                    else:
                        # Model is wider: copy checkpoint columns, zero the new ones.
                        model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1]] = state_dict['patch_embed.proj.weight']
                        model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:] = 0
                        state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
                else:
                    # Model is narrower: truncate the checkpoint columns.
                    model.state_dict()['patch_embed.proj.weight'][:, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1]]
                    state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
            else:
                # 4D conv patch embed: same widen/truncate logic on the channel axis.
                if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
                    model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
                    model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
                    state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
                else:
                    model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
                    state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']

        # Keep only tensors whose shapes match the freshly-built model.
        tmp_state_dict = {}
        for key in state_dict:
            if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
                tmp_state_dict[key] = state_dict[key]
            else:
                print(key, "Size don't match, skip")

        state_dict = tmp_state_dict

        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m)

        # Rough parameter counts for logging.
        params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
        print(f"### All Parameters: {sum(params) / 1e6} M")

        params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
        print(f"### attn1 Parameters: {sum(params) / 1e6} M")

        model = model.to(torch_dtype)
        return model
videox_fun/models/cogvideox_vae.py ADDED
@@ -0,0 +1,1675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
# Standard library
import json
import os
from typing import Dict, Optional, Tuple, Union

# Third-party
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.activations import get_activation
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.downsampling import CogVideoXDownsample3D
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
# NOTE: a local CogVideoXUpsample3D (with streaming/first-frame support) is
# defined later in this module and intentionally shadows this import.
from diffusers.models.upsampling import CogVideoXUpsample3D
from diffusers.utils import logging
from diffusers.utils.accelerate_utils import apply_forward_hook


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
38
+
39
+
40
class CogVideoXSafeConv3d(nn.Conv3d):
    r"""
    A 3D convolution that chunks oversized inputs along the temporal axis to keep
    each underlying CuDNN call within a safe memory budget (avoids OOM in the
    CogVideoX model).
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Rough activation footprint in GiB, assuming 2 bytes per element.
        memory_count = input.numel() * 2 / 1024**3

        # Small inputs take the stock Conv3d path (budget: 2GB, suits CuDNN).
        if memory_count <= 2:
            return super().forward(input)

        # Split the clip into roughly 2 GiB temporal pieces.
        part_num = int(memory_count / 2) + 1
        chunks = list(torch.chunk(input, part_num, dim=2))

        # Restore the temporal receptive field at every seam by prepending the
        # last (kernel - 1) frames of the *original* previous chunk. Iterating
        # backwards guarantees chunks[i - 1] is still unmodified when read.
        kernel_size = self.kernel_size[0]
        if kernel_size > 1:
            for i in range(len(chunks) - 1, 0, -1):
                overlap = chunks[i - 1][:, :, -kernel_size + 1 :]
                chunks[i] = torch.cat((overlap, chunks[i]), dim=2)

        outputs = [super(CogVideoXSafeConv3d, self).forward(chunk) for chunk in chunks]
        return torch.cat(outputs, dim=2)
69
+
70
+
71
class CogVideoXCausalConv3d(nn.Module):
    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.

    Args:
        in_channels (`int`): Number of channels in the input tensor.
        out_channels (`int`): Number of output channels produced by the convolution.
        kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
        stride (`int`, defaults to `1`): Stride of the convolution.
        dilation (`int`, defaults to `1`): Dilation rate of the convolution.
        pad_mode (`str`, defaults to `"constant"`): Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: int = 1,
        dilation: int = 1,
        pad_mode: str = "constant",
    ):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        # TODO(aryan): configure calculation based on stride and dilation in the future.
        # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
        time_pad = time_kernel_size - 1
        height_pad = (height_kernel_size - 1) // 2
        width_pad = (width_kernel_size - 1) // 2

        self.pad_mode = pad_mode
        self.height_pad = height_pad
        self.width_pad = width_pad
        self.time_pad = time_pad
        # F.pad ordering: (w_left, w_right, h_top, h_bottom, t_front, t_back);
        # all temporal padding goes *before* the clip, hence "causal".
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        self.temporal_dim = 2
        self.time_kernel_size = time_kernel_size

        # Stride/dilation are applied on the temporal axis only.
        stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = CogVideoXSafeConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def fake_context_parallel_forward(
        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Prepend temporal context: either the cached tail of the previous chunk
        # (streaming decode) or the first frame replicated (kernel_size - 1) times.
        if self.pad_mode == "replicate":
            inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
        else:
            kernel_size = self.time_kernel_size
            if kernel_size > 1:
                cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
                inputs = torch.cat(cached_inputs + [inputs], dim=2)
        return inputs

    def forward(
        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Returns (output, new_conv_cache): the cache is the temporal tail to feed
        # into the next chunk's forward() for seamless streaming decode.
        inputs = self.fake_context_parallel_forward(inputs, conv_cache)

        if self.pad_mode == "replicate":
            conv_cache = None
        else:
            padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
            # clone() so the cache owns its storage rather than being a view that
            # keeps the whole padded input tensor alive.
            conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
            inputs = F.pad(inputs, padding_2d, mode="constant", value=0)

        output = self.conv(inputs)
        return output, conv_cache
148
+
149
+
150
class CogVideoXSpatialNorm3D(nn.Module):
    r"""
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
    to 3D-video like data.

    CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.

    Args:
        f_channels (`int`):
            The number of channels for input to group normalization layer, and output of the spatial norm layer.
        zq_channels (`int`):
            The number of channels for the quantized vector as described in the paper.
        groups (`int`):
            Number of groups to separate the channels into for group normalization.
    """

    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
        groups: int = 32,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
        # 1x1x1 causal convs producing the per-pixel scale (y) and bias (b).
        self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
        self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)

    def forward(
        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # Returns (modulated features, new conv caches keyed "conv_y"/"conv_b").
        new_conv_cache = {}
        conv_cache = conv_cache or {}

        # For odd frame counts (> 1) the first frame is resized separately so the
        # causal first-frame/rest split used elsewhere in the VAE is respected.
        if f.shape[2] > 1 and f.shape[2] % 2 == 1:
            f_first, f_rest = f[:, :, :1], f[:, :, 1:]
            f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
            z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
            z_first = F.interpolate(z_first, size=f_first_size)
            z_rest = F.interpolate(z_rest, size=f_rest_size)
            zq = torch.cat([z_first, z_rest], dim=2)
        else:
            zq = F.interpolate(zq, size=f.shape[-3:])

        conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
        conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))

        # FiLM-style modulation of the group-normalized features.
        norm_f = self.norm_layer(f)
        new_f = norm_f * conv_y + conv_b
        return new_f, new_conv_cache
199
+
200
+
201
class CogVideoXUpsample3D(nn.Module):
    r"""
    A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.

    Spatial dimensions are always doubled via nearest-neighbor interpolation
    followed by a 2D convolution applied frame by frame. With ``compress_time``
    the temporal dimension may also be expanded, with the first frame handled
    separately to preserve causality.

    Args:
        in_channels (`int`):
            Number of channels in the input image.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `1`):
            Stride of the convolution.
        padding (`int`, defaults to `1`):
            Padding added to all four sides of the input.
        compress_time (`bool`, defaults to `False`):
            Whether or not to compress the time dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        compress_time: bool = False,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.compress_time = compress_time

        # Streaming-decode controls: when auto_split_process is off, the caller
        # marks whether the current chunk holds the very first frame.
        self.auto_split_process = True
        self.first_frame_flag = False

    def _upsample_single_frame(self, x: torch.Tensor) -> torch.Tensor:
        # 2x spatial upsample of a one-frame clip without touching time.
        x = F.interpolate(x.squeeze(2), scale_factor=2.0)
        return x[:, :, None, :, :]

    def _upsample_compress_time(self, x: torch.Tensor) -> torch.Tensor:
        # Temporal + spatial 2x upsample used when compress_time is enabled.
        if self.auto_split_process:
            num_frames = x.shape[2]
            if num_frames > 1 and num_frames % 2 == 1:
                # Odd frame count: keep the first frame temporally fixed,
                # upsample the remainder in both time and space.
                x_first, x_rest = x[:, :, 0], x[:, :, 1:]
                x_first = F.interpolate(x_first, scale_factor=2.0)
                x_rest = F.interpolate(x_rest, scale_factor=2.0)
                return torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
            if num_frames > 1:
                return F.interpolate(x, scale_factor=2.0)
            return self._upsample_single_frame(x)
        if self.first_frame_flag:
            return self._upsample_single_frame(x)
        return F.interpolate(x, scale_factor=2.0)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if self.compress_time:
            inputs = self._upsample_compress_time(inputs)
        else:
            # only interpolate 2D (spatial axes), frame by frame
            b, c, t, h, w = inputs.shape
            flat = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
            flat = F.interpolate(flat, scale_factor=2.0)
            inputs = flat.reshape(b, t, c, *flat.shape[2:]).permute(0, 2, 1, 3, 4)

        # Frame-wise 2D convolution over the upsampled feature maps.
        b, c, t, h, w = inputs.shape
        flat = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        flat = self.conv(flat)
        return flat.reshape(b, t, *flat.shape[1:]).permute(0, 2, 1, 3, 4)
274
+
275
+
276
class CogVideoXResnetBlock3D(nn.Module):
    r"""
    A 3D ResNet block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        non_linearity (`str`, defaults to `"swish"`):
            Activation function to use.
        conv_shortcut (bool, defaults to `False`):
            Whether or not to use a convolution shortcut.
        spatial_norm_dim (`int`, *optional*):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        conv_shortcut: bool = False,
        spatial_norm_dim: Optional[int] = None,
        pad_mode: str = "first",
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nonlinearity = get_activation(non_linearity)
        self.use_conv_shortcut = conv_shortcut
        self.spatial_norm_dim = spatial_norm_dim

        # Decoder blocks condition normalization on the latent (`spatial_norm_dim` set);
        # encoder blocks use plain group norm.
        if spatial_norm_dim is None:
            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
            self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
        else:
            self.norm1 = CogVideoXSpatialNorm3D(
                f_channels=in_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )
            self.norm2 = CogVideoXSpatialNorm3D(
                f_channels=out_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )

        self.conv1 = CogVideoXCausalConv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        # Time-embedding projection is only created when a time embedding is expected;
        # the VAE instantiates these blocks with temb_channels=0, so it is usually absent.
        if temb_channels > 0:
            self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)

        self.dropout = nn.Dropout(dropout)
        self.conv2 = CogVideoXCausalConv3d(
            in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        # Channel-changing residual path: 3x3 causal conv if requested, otherwise a
        # cheap 1x1 projection.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = CogVideoXCausalConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
                )
            else:
                self.conv_shortcut = CogVideoXSafeConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(
        self,
        inputs: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run norm -> act -> conv twice with a residual add.

        `conv_cache` carries per-submodule state for the causal convolutions
        (presumably trailing frames for chunked temporal processing — confirm in
        `CogVideoXCausalConv3d`); a fresh cache dict is rebuilt and returned every call.
        """
        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states = inputs

        # zq is the latent used by the spatially-conditioned norm (decoder path only).
        if zq is not None:
            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))

        # Broadcast the projected time embedding over frames/height/width.
        if temb is not None:
            hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]

        if zq is not None:
            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
        else:
            hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))

        # Project the residual when channel counts differ; only the causal-conv
        # variant participates in the conv cache.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
                    inputs, conv_cache=conv_cache.get("conv_shortcut")
                )
            else:
                inputs = self.conv_shortcut(inputs)

        hidden_states = hidden_states + inputs
        return hidden_states, new_conv_cache
405
+
406
+
407
class CogVideoXDownBlock3D(nn.Module):
    r"""
    A downsampling block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        add_downsample (`bool`, defaults to `True`):
            Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
        compress_time (`bool`, defaults to `False`):
            Whether or not to downsample across temporal dimension.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_padding: int = 0,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        resnets = []
        for i in range(num_layers):
            # Only the first resnet performs the channel change; the rest keep out_channels.
            in_channel = in_channels if i == 0 else out_channels
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.downsamplers = None

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    CogVideoXDownsample3D(
                        out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""Forward method of the `CogVideoXDownBlock3D` class.

        Runs the resnet stack (optionally under gradient checkpointing), then the
        downsamplers. Per-resnet conv caches are keyed "resnet_{i}" and returned
        alongside the output so callers can thread them across temporal chunks.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:

                # Wrapper closure so `checkpoint` receives a plain callable with
                # positional tensor args (required for activation recomputation).
                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    zq,
                    conv_cache.get(conv_cache_key),
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        # NOTE: the downsamplers are not part of the conv cache.
        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states, new_conv_cache
525
+
526
+
527
class CogVideoXMidBlock3D(nn.Module):
    r"""
    The bottleneck block of the CogVideoX VAE: a plain stack of
    `CogVideoXResnetBlock3D` layers with no resampling.

    Args:
        in_channels (`int`):
            Number of input channels.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        spatial_norm_dim (`int`, *optional*):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: Optional[int] = None,
        pad_mode: str = "first",
    ):
        super().__init__()

        # Every layer preserves the channel count.
        self.resnets = nn.ModuleList(
            [
                CogVideoXResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    spatial_norm_dim=spatial_norm_dim,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
                for _ in range(num_layers)
            ]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        r"""Forward method of the `CogVideoXMidBlock3D` class.

        Threads a per-resnet conv cache (keys "resnet_{i}") through the stack and
        returns the refreshed cache together with the output tensor.
        """
        conv_cache = conv_cache or {}
        new_conv_cache = {}
        use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing

        for index, resnet in enumerate(self.resnets):
            cache_key = f"resnet_{index}"

            if use_checkpointing:
                # Bind the module via a default argument so `checkpoint` gets a
                # plain callable taking positional tensors.
                def run_resnet(*args, _module=resnet):
                    return _module(*args)

                hidden_states, new_conv_cache[cache_key] = torch.utils.checkpoint.checkpoint(
                    run_resnet, hidden_states, temb, zq, conv_cache.get(cache_key)
                )
            else:
                hidden_states, new_conv_cache[cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(cache_key)
                )

        return hidden_states, new_conv_cache
619
+
620
+
621
class CogVideoXUpBlock3D(nn.Module):
    r"""
    An upsampling block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        spatial_norm_dim (`int`, defaults to `16`):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        add_upsample (`bool`, defaults to `True`):
            Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
        compress_time (`bool`, defaults to `False`):
            Whether or not to downsample across temporal dimension.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: int = 16,
        add_upsample: bool = True,
        upsample_padding: int = 1,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        resnets = []
        for i in range(num_layers):
            # Only the first resnet performs the channel change.
            in_channel = in_channels if i == 0 else out_channels
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    spatial_norm_dim=spatial_norm_dim,
                    pad_mode=pad_mode,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.upsamplers = None

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    CogVideoXUpsample3D(
                        out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""Forward method of the `CogVideoXUpBlock3D` class.

        Mirrors `CogVideoXDownBlock3D.forward`: runs the resnet stack (optionally
        under gradient checkpointing) threading per-resnet conv caches keyed
        "resnet_{i}", then applies the upsamplers (which are not cached).
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:

                # Wrapper closure so `checkpoint` receives a plain positional-arg callable.
                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    zq,
                    conv_cache.get(conv_cache_key),
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states, new_conv_cache
741
+
742
+
743
class CogVideoXEncoder3D(nn.Module):
    r"""
    The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
            options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        down_block_types: Tuple[str, ...] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # log2 of temporal_compress_times
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
        self.down_blocks = nn.ModuleList([])

        # down blocks
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # Only the first `temporal_compress_level` stages compress time as well as space.
            compress_time = i < temporal_compress_level

            if down_block_type == "CogVideoXDownBlock3D":
                # NOTE(review): `pad_mode` is not forwarded here, so the down-block
                # resnets always use their default ("first"); harmless with the
                # default config but worth confirming if pad_mode is ever changed.
                down_block = CogVideoXDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    add_downsample=not is_final_block,
                    compress_time=compress_time,
                )
            else:
                raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")

            self.down_blocks.append(down_block)

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=block_out_channels[-1],
            temb_channels=0,
            dropout=dropout,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            pad_mode=pad_mode,
        )

        self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
        self.conv_act = nn.SiLU()
        # 2 * out_channels: presumably the mean/log-variance halves of the latent
        # distribution (standard VAE convention) — confirm against the posterior code.
        self.conv_out = CogVideoXCausalConv3d(
            block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""The forward method of the `CogVideoXEncoder3D` class.

        Returns the encoded tensor plus the refreshed conv cache (keys: "conv_in",
        "down_block_{i}", "mid_block", "conv_out") for chunked temporal encoding.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

        if torch.is_grad_enabled() and self.gradient_checkpointing:

            # Wrapper so `checkpoint` receives a plain positional-arg callable.
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Down
            for i, down_block in enumerate(self.down_blocks):
                conv_cache_key = f"down_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block),
                    hidden_states,
                    temb,
                    None,  # zq: spatial-norm conditioning is decoder-only
                    conv_cache.get(conv_cache_key),
                )

            # 2. Mid
            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                temb,
                None,
                conv_cache.get("mid_block"),
            )
        else:
            # 1. Down
            for i, down_block in enumerate(self.down_blocks):
                conv_cache_key = f"down_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = down_block(
                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
                )

            # 2. Mid
            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
            )

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)

        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))

        return hidden_states, new_conv_cache
900
+
901
+
902
class CogVideoXDecoder3D(nn.Module):
    r"""
    The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # The decoder walks the channel schedule in reverse (deepest first).
        reversed_block_out_channels = list(reversed(block_out_channels))

        self.conv_in = CogVideoXCausalConv3d(
            in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
        )

        # mid block — spatial-norm conditioned on the latent (spatial_norm_dim=in_channels).
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=reversed_block_out_channels[0],
            temb_channels=0,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            spatial_norm_dim=in_channels,
            pad_mode=pad_mode,
        )

        # up blocks
        self.up_blocks = nn.ModuleList([])

        output_channel = reversed_block_out_channels[0]
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # The first `temporal_compress_level` stages also upsample time.
            compress_time = i < temporal_compress_level

            if up_block_type == "CogVideoXUpBlock3D":
                # Decoder blocks carry one extra resnet layer (layers_per_block + 1),
                # mirroring the usual VAE encoder/decoder asymmetry.
                up_block = CogVideoXUpBlock3D(
                    in_channels=prev_output_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    spatial_norm_dim=in_channels,
                    add_upsample=not is_final_block,
                    compress_time=compress_time,
                    pad_mode=pad_mode,
                )
                prev_output_channel = output_channel
            else:
                raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")

            self.up_blocks.append(up_block)

        self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""The forward method of the `CogVideoXDecoder3D` class.

        The raw latent `sample` is passed as `zq` to every block so the spatial
        norms stay conditioned on it throughout decoding. Returns the decoded
        tensor plus the refreshed conv cache (keys: "conv_in", "mid_block",
        "up_block_{i}", "norm_out", "conv_out").
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

        if torch.is_grad_enabled() and self.gradient_checkpointing:

            # Wrapper so `checkpoint` receives a plain positional-arg callable.
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Mid
            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                temb,
                sample,  # zq conditioning
                conv_cache.get("mid_block"),
            )

            # 2. Up
            for i, up_block in enumerate(self.up_blocks):
                conv_cache_key = f"up_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block),
                    hidden_states,
                    temb,
                    sample,
                    conv_cache.get(conv_cache_key),
                )
        else:
            # 1. Mid
            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
            )

            # 2. Up
            for i, up_block in enumerate(self.up_blocks):
                conv_cache_key = f"up_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = up_block(
                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
                )

        # 3. Post-process
        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
        )
        hidden_states = self.conv_act(hidden_states)
        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))

        return hidden_states, new_conv_cache
1067
+
1068
+
1069
+ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
1070
+ r"""
1071
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
1072
+ [CogVideoX](https://github.com/THUDM/CogVideo).
1073
+
1074
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
1075
+ for all models (such as downloading or saving).
1076
+
1077
+ Parameters:
1078
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
1079
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
1080
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
1081
+ Tuple of downsample block types.
1082
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
1083
+ Tuple of upsample block types.
1084
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
1085
+ Tuple of block output channels.
1086
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
1087
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
1088
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
1089
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
1090
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
1091
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
1092
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
1093
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
1094
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
1095
+ force_upcast (`bool`, *optional*, default to `True`):
1096
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
1097
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
1098
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
1099
+ """
1100
+
1101
+ _supports_gradient_checkpointing = True
1102
+ _no_split_modules = ["CogVideoXResnetBlock3D"]
1103
+
1104
+ @register_to_config
1105
+ def __init__(
1106
+ self,
1107
+ in_channels: int = 3,
1108
+ out_channels: int = 3,
1109
+ down_block_types: Tuple[str] = (
1110
+ "CogVideoXDownBlock3D",
1111
+ "CogVideoXDownBlock3D",
1112
+ "CogVideoXDownBlock3D",
1113
+ "CogVideoXDownBlock3D",
1114
+ ),
1115
+ up_block_types: Tuple[str] = (
1116
+ "CogVideoXUpBlock3D",
1117
+ "CogVideoXUpBlock3D",
1118
+ "CogVideoXUpBlock3D",
1119
+ "CogVideoXUpBlock3D",
1120
+ ),
1121
+ block_out_channels: Tuple[int] = (128, 256, 256, 512),
1122
+ latent_channels: int = 16,
1123
+ layers_per_block: int = 3,
1124
+ act_fn: str = "silu",
1125
+ norm_eps: float = 1e-6,
1126
+ norm_num_groups: int = 32,
1127
+ temporal_compression_ratio: float = 4,
1128
+ sample_height: int = 480,
1129
+ sample_width: int = 720,
1130
+ scaling_factor: float = 1.15258426,
1131
+ shift_factor: Optional[float] = None,
1132
+ latents_mean: Optional[Tuple[float]] = None,
1133
+ latents_std: Optional[Tuple[float]] = None,
1134
+ force_upcast: float = True,
1135
+ use_quant_conv: bool = False,
1136
+ use_post_quant_conv: bool = False,
1137
+ invert_scale_latents: bool = False,
1138
+ ):
1139
+ super().__init__()
1140
+
1141
+ self.encoder = CogVideoXEncoder3D(
1142
+ in_channels=in_channels,
1143
+ out_channels=latent_channels,
1144
+ down_block_types=down_block_types,
1145
+ block_out_channels=block_out_channels,
1146
+ layers_per_block=layers_per_block,
1147
+ act_fn=act_fn,
1148
+ norm_eps=norm_eps,
1149
+ norm_num_groups=norm_num_groups,
1150
+ temporal_compression_ratio=temporal_compression_ratio,
1151
+ )
1152
+ self.decoder = CogVideoXDecoder3D(
1153
+ in_channels=latent_channels,
1154
+ out_channels=out_channels,
1155
+ up_block_types=up_block_types,
1156
+ block_out_channels=block_out_channels,
1157
+ layers_per_block=layers_per_block,
1158
+ act_fn=act_fn,
1159
+ norm_eps=norm_eps,
1160
+ norm_num_groups=norm_num_groups,
1161
+ temporal_compression_ratio=temporal_compression_ratio,
1162
+ )
1163
+ self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
1164
+ self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
1165
+
1166
+ self.use_slicing = False
1167
+ self.use_tiling = False
1168
+ self.auto_split_process = False
1169
+
1170
+ # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
1171
+ # recommended because the temporal parts of the VAE, here, are tricky to understand.
1172
+ # If you decode X latent frames together, the number of output frames is:
1173
+ # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
1174
+ #
1175
+ # Example with num_latent_frames_batch_size = 2:
1176
+ # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
1177
+ # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
1178
+ # => 6 * 8 = 48 frames
1179
+ # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
1180
+ # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
1181
+ # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
1182
+ # => 1 * 9 + 5 * 8 = 49 frames
1183
+ # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
1184
+ # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
1185
+ # number of temporal frames.
1186
+ self.num_latent_frames_batch_size = 2
1187
+ self.num_sample_frames_batch_size = 8
1188
+
1189
+ # We make the minimum height and width of sample for tiling half that of the generally supported
1190
+ self.tile_sample_min_height = sample_height // 2
1191
+ self.tile_sample_min_width = sample_width // 2
1192
+ self.tile_latent_min_height = int(
1193
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1194
+ )
1195
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1196
+
1197
+ # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
1198
+ # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
1199
+ # and so the tiling implementation has only been tested on those specific resolutions.
1200
+ self.tile_overlap_factor_height = 1 / 6
1201
+ self.tile_overlap_factor_width = 1 / 5
1202
+
1203
def _set_gradient_checkpointing(self, module, value=False):
    """Toggle gradient checkpointing on encoder/decoder submodules only."""
    checkpointable = (CogVideoXEncoder3D, CogVideoXDecoder3D)
    if isinstance(module, checkpointable):
        module.gradient_checkpointing = value
1206
+
1207
def enable_tiling(
    self,
    tile_sample_min_height: Optional[int] = None,
    tile_sample_min_width: Optional[int] = None,
    tile_overlap_factor_height: Optional[float] = None,
    tile_overlap_factor_width: Optional[float] = None,
) -> None:
    r"""
    Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
    compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
    processing larger images.

    Args:
        tile_sample_min_height (`int`, *optional*):
            The minimum height required for a sample to be separated into tiles across the height dimension.
        tile_sample_min_width (`int`, *optional*):
            The minimum width required for a sample to be separated into tiles across the width dimension.
        tile_overlap_factor_height (`int`, *optional*):
            The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
            no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
            value might cause more tiles to be processed leading to slow down of the decoding process.
        tile_overlap_factor_width (`int`, *optional*):
            The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
            are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
            value might cause more tiles to be processed leading to slow down of the decoding process.
    """
    self.use_tiling = True
    # `or` keeps the previous value when the argument is None. NOTE(review):
    # a caller passing 0 (falsy) would also silently fall back to the old value.
    self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
    self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
    # Latent tile size = sample tile size divided by the total spatial
    # downsampling factor 2 ** (num_resolution_levels - 1).
    self.tile_latent_min_height = int(
        self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
    )
    self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
    self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
    self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
1242
+
1243
def disable_tiling(self) -> None:
    r"""
    Turn tiled VAE decoding back off. After this call, encoding/decoding runs
    on the full tensor in a single pass again.
    """
    self.use_tiling = False
1249
+
1250
def enable_slicing(self) -> None:
    r"""
    Turn on sliced VAE decoding: batched inputs are split into single-sample
    slices and processed one at a time, trading speed for lower peak memory.
    """
    self.use_slicing = True
1256
+
1257
def disable_slicing(self) -> None:
    r"""
    Turn sliced VAE decoding back off; batches are decoded in one step again.
    """
    self.use_slicing = False
1263
+
1264
+ def _set_first_frame(self):
1265
+ for name, module in self.named_modules():
1266
+ if isinstance(module, CogVideoXUpsample3D):
1267
+ module.auto_split_process = False
1268
+ module.first_frame_flag = True
1269
+
1270
+ def _set_rest_frame(self):
1271
+ for name, module in self.named_modules():
1272
+ if isinstance(module, CogVideoXUpsample3D):
1273
+ module.auto_split_process = False
1274
+ module.first_frame_flag = False
1275
+
1276
def enable_auto_split_process(self) -> None:
    # Enable automatic temporal chunking both on the VAE itself and on every
    # temporal upsampler submodule.
    self.auto_split_process = True
    for _, submodule in self.named_modules():
        if isinstance(submodule, CogVideoXUpsample3D):
            submodule.auto_split_process = True
1281
+
1282
def disable_auto_split_process(self) -> None:
    # NOTE: only the VAE-level flag is cleared here; upsampler submodules are
    # left untouched (mirrors the original behavior — they are reconfigured by
    # _set_first_frame/_set_rest_frame during manual decoding).
    self.auto_split_process = False
1284
+
1285
def _encode(self, x: torch.Tensor) -> torch.Tensor:
    # Encode a video tensor (B, C, T, H, W) into latents, splitting the
    # temporal axis into `num_sample_frames_batch_size` chunks so the causal
    # conv cache keeps peak memory bounded.
    batch_size, num_channels, num_frames, height, width = x.shape

    # Spatially oversized inputs go through the tiled path instead.
    if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
        return self.tiled_encode(x)

    frame_batch_size = self.num_sample_frames_batch_size
    # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
    # As the extra single frame is handled inside the loop, it is not required to round up here.
    num_batches = max(num_frames // frame_batch_size, 1)
    conv_cache = None
    enc = []

    for i in range(num_batches):
        # The leftover `num_frames % frame_batch_size` frames are folded into
        # the first chunk; every later chunk is shifted by that remainder.
        remaining_frames = num_frames % frame_batch_size
        start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
        end_frame = frame_batch_size * (i + 1) + remaining_frames
        x_intermediate = x[:, :, start_frame:end_frame]
        # `conv_cache` carries causal-conv state from one temporal chunk to
        # the next so the chunked result matches a single full pass.
        x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
        if self.quant_conv is not None:
            x_intermediate = self.quant_conv(x_intermediate)
        enc.append(x_intermediate)

    # Re-assemble chunks along the temporal axis.
    enc = torch.cat(enc, dim=2)
    return enc
1310
+
1311
@apply_forward_hook
def encode(
    self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
    """
    Encode a batch of images into latents.

    Args:
        x (`torch.Tensor`): Input batch of images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

    Returns:
        The latent representations of the encoded videos. If `return_dict` is True, a
        [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
    """
    # With slicing enabled, each batch element is encoded separately to lower
    # peak memory; results are concatenated back on the batch axis.
    if self.use_slicing and x.shape[0] > 1:
        encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
        h = torch.cat(encoded_slices)
    else:
        h = self._encode(x)

    # `h` holds concatenated mean/logvar channels; wrap as a Gaussian posterior.
    posterior = DiagonalGaussianDistribution(h)

    if not return_dict:
        return (posterior,)
    return AutoencoderKLOutput(latent_dist=posterior)
1338
+
1339
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
    # Decode latents (B, C, T_latent, H, W) back to pixel space, processing
    # the temporal axis in chunks of `num_latent_frames_batch_size`.
    batch_size, num_channels, num_frames, height, width = z.shape

    # Spatially oversized latents go through the tiled path instead.
    if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
        return self.tiled_decode(z, return_dict=return_dict)

    if self.auto_split_process:
        # Automatic chunking: same remainder-folding scheme as `_encode`,
        # with the conv cache threaded through consecutive chunks.
        frame_batch_size = self.num_latent_frames_batch_size
        num_batches = max(num_frames // frame_batch_size, 1)
        conv_cache = None
        dec = []

        for i in range(num_batches):
            remaining_frames = num_frames % frame_batch_size
            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
            end_frame = frame_batch_size * (i + 1) + remaining_frames
            z_intermediate = z[:, :, start_frame:end_frame]
            if self.post_quant_conv is not None:
                z_intermediate = self.post_quant_conv(z_intermediate)
            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
            dec.append(z_intermediate)
    else:
        # Manual chunking: the first latent frame is decoded alone with the
        # upsamplers in "first frame" mode, then the rest in fixed-size chunks.
        conv_cache = None
        start_frame = 0
        end_frame = 1
        dec = []

        self._set_first_frame()
        z_intermediate = z[:, :, start_frame:end_frame]
        if self.post_quant_conv is not None:
            z_intermediate = self.post_quant_conv(z_intermediate)
        z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
        dec.append(z_intermediate)

        self._set_rest_frame()
        start_frame = end_frame
        end_frame += self.num_latent_frames_batch_size

        # Decode remaining frames chunk by chunk (slicing tolerates the final
        # chunk running past num_frames).
        while start_frame < num_frames:
            z_intermediate = z[:, :, start_frame:end_frame]
            if self.post_quant_conv is not None:
                z_intermediate = self.post_quant_conv(z_intermediate)
            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
            dec.append(z_intermediate)
            start_frame = end_frame
            end_frame += self.num_latent_frames_batch_size

    dec = torch.cat(dec, dim=2)

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
1392
+
1393
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
    """
    Decode a batch of images.

    Args:
        z (`torch.Tensor`): Input batch of latent vectors.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

    Returns:
        [`~models.vae.DecoderOutput`] or `tuple`:
            If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
            returned.
    """
    # With slicing enabled, each batch element is decoded separately to lower
    # peak memory; results are concatenated back on the batch axis.
    if self.use_slicing and z.shape[0] > 1:
        decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
        decoded = torch.cat(decoded_slices)
    else:
        decoded = self._decode(z).sample

    if not return_dict:
        return (decoded,)
    return DecoderOutput(sample=decoded)
1417
+
1418
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the top rows of `b` with the bottom rows of `a`.

    Operates along dim 3 (height) of (B, C, T, H, W) tensors. `b` is mutated
    in place and returned; row weight ramps linearly from pure-`a` to pure-`b`.
    """
    extent = min(a.shape[3], b.shape[3], blend_extent)
    for row in range(extent):
        weight = row / extent
        b[:, :, :, row, :] = (
            a[:, :, :, -extent + row, :] * (1 - weight) + b[:, :, :, row, :] * weight
        )
    return b
1425
+
1426
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the left columns of `b` with the right columns of `a`.

    Operates along dim 4 (width) of (B, C, T, H, W) tensors. `b` is mutated
    in place and returned; column weight ramps linearly from pure-`a` to pure-`b`.
    """
    extent = min(a.shape[4], b.shape[4], blend_extent)
    for col in range(extent):
        weight = col / extent
        b[:, :, :, :, col] = (
            a[:, :, :, :, -extent + col] * (1 - weight) + b[:, :, :, :, col] * weight
        )
    return b
1433
+
1434
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
    r"""Encode a batch of images using a tiled encoder.

    When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
    steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
    different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
    tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
    output, but they should be much less noticeable.

    Args:
        x (`torch.Tensor`): Input batch of videos.

    Returns:
        `torch.Tensor`:
            The latent representation of the encoded videos.
    """
    # For a rough memory estimate, take a look at the `tiled_decode` method.
    batch_size, num_channels, num_frames, height, width = x.shape

    # Tile stride in pixel space, blend width in latent space: overlapping
    # latent rows/cols are cross-faded and then cropped to `row_limit_*`.
    overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
    overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
    blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
    blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
    row_limit_height = self.tile_latent_min_height - blend_extent_height
    row_limit_width = self.tile_latent_min_width - blend_extent_width
    frame_batch_size = self.num_sample_frames_batch_size

    # Split x into overlapping tiles and encode them separately.
    # The tiles have an overlap to avoid seams between tiles.
    rows = []
    for i in range(0, height, overlap_height):
        row = []
        for j in range(0, width, overlap_width):
            # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
            # As the extra single frame is handled inside the loop, it is not required to round up here.
            num_batches = max(num_frames // frame_batch_size, 1)
            # The conv cache is reset per spatial tile, then threaded across
            # that tile's temporal chunks (same scheme as `_encode`).
            conv_cache = None
            time = []

            for k in range(num_batches):
                remaining_frames = num_frames % frame_batch_size
                start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                end_frame = frame_batch_size * (k + 1) + remaining_frames
                tile = x[
                    :,
                    :,
                    start_frame:end_frame,
                    i : i + self.tile_sample_min_height,
                    j : j + self.tile_sample_min_width,
                ]
                tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
                if self.quant_conv is not None:
                    tile = self.quant_conv(tile)
                time.append(tile)

            row.append(torch.cat(time, dim=2))
        rows.append(row)

    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # blend the above tile and the left tile
            # to the current tile and add the current tile to the result row
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent_width)
            result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
        result_rows.append(torch.cat(result_row, dim=4))

    # Stitch tile rows back together along the latent height axis.
    enc = torch.cat(result_rows, dim=3)
    return enc
1507
+
1508
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
    r"""
    Decode a batch of images using a tiled decoder.

    Args:
        z (`torch.Tensor`): Input batch of latent vectors.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

    Returns:
        [`~models.vae.DecoderOutput`] or `tuple`:
            If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
            returned.
    """
    # Rough memory assessment:
    #   - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
    #   - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
    #   - Assume fp16 (2 bytes per value).
    # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
    #
    # Memory assessment when using tiling:
    #   - Assume everything as above but now HxW is 240x360 by tiling in half
    # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB

    batch_size, num_channels, num_frames, height, width = z.shape

    # Tile stride in latent space, blend width in pixel space: overlapping
    # decoded rows/cols are cross-faded and then cropped to `row_limit_*`.
    overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
    overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
    blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
    blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
    row_limit_height = self.tile_sample_min_height - blend_extent_height
    row_limit_width = self.tile_sample_min_width - blend_extent_width
    frame_batch_size = self.num_latent_frames_batch_size

    # Split z into overlapping tiles and decode them separately.
    # The tiles have an overlap to avoid seams between tiles.
    rows = []
    for i in range(0, height, overlap_height):
        row = []
        for j in range(0, width, overlap_width):
            if self.auto_split_process:
                # Automatic temporal chunking per tile (see `_decode`).
                num_batches = max(num_frames // frame_batch_size, 1)
                conv_cache = None
                time = []

                for k in range(num_batches):
                    remaining_frames = num_frames % frame_batch_size
                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                    end_frame = frame_batch_size * (k + 1) + remaining_frames
                    tile = z[
                        :,
                        :,
                        start_frame:end_frame,
                        i : i + self.tile_latent_min_height,
                        j : j + self.tile_latent_min_width,
                    ]
                    if self.post_quant_conv is not None:
                        tile = self.post_quant_conv(tile)
                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                    time.append(tile)

                row.append(torch.cat(time, dim=2))
            else:
                # Manual chunking per tile: first latent frame decoded alone
                # in "first frame" mode, then fixed-size chunks (see `_decode`).
                conv_cache = None
                start_frame = 0
                end_frame = 1
                dec = []

                tile = z[
                    :,
                    :,
                    start_frame:end_frame,
                    i : i + self.tile_latent_min_height,
                    j : j + self.tile_latent_min_width,
                ]

                self._set_first_frame()
                if self.post_quant_conv is not None:
                    tile = self.post_quant_conv(tile)
                tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                dec.append(tile)

                self._set_rest_frame()
                start_frame = end_frame
                end_frame += self.num_latent_frames_batch_size

                while start_frame < num_frames:
                    tile = z[
                        :,
                        :,
                        start_frame:end_frame,
                        i : i + self.tile_latent_min_height,
                        j : j + self.tile_latent_min_width,
                    ]
                    if self.post_quant_conv is not None:
                        tile = self.post_quant_conv(tile)
                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                    dec.append(tile)
                    start_frame = end_frame
                    end_frame += self.num_latent_frames_batch_size

                row.append(torch.cat(dec, dim=2))
        rows.append(row)

    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # blend the above tile and the left tile
            # to the current tile and add the current tile to the result row
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent_width)
            result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
        result_rows.append(torch.cat(result_row, dim=4))

    # Stitch tile rows back together along the pixel height axis.
    dec = torch.cat(result_rows, dim=3)

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
1631
+
1632
def forward(
    self,
    sample: torch.Tensor,
    sample_posterior: bool = False,
    return_dict: bool = True,
    generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, Tuple[DecoderOutput]]:
    # Full autoencode pass: encode `sample`, take a draw (or the mode) from the
    # latent posterior, then decode back to pixel space.
    x = sample
    posterior = self.encode(x).latent_dist
    if sample_posterior:
        # Stochastic draw from the diagonal Gaussian posterior.
        z = posterior.sample(generator=generator)
    else:
        # Deterministic: use the posterior mean.
        z = posterior.mode()
    dec = self.decode(z)
    # NOTE(review): `dec` is already a DecoderOutput here, so with
    # return_dict=False this yields a tuple *containing a DecoderOutput*, not
    # the raw tensor — same quirk exists upstream in diffusers; confirm callers.
    if not return_dict:
        return (dec,)
    return dec
1649
+
1650
@classmethod
def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
    """Load the VAE from a local directory containing `config.json` plus weights.

    A `.safetensors` checkpoint (diffusers' WEIGHTS_NAME with the `.bin`
    suffix swapped) is preferred; otherwise the `.bin` checkpoint is loaded.
    Loading is non-strict: missing/unexpected keys are printed, not raised.

    Raises:
        RuntimeError: if the config file or the weight file does not exist.
    """
    if subfolder is not None:
        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)

    config_file = os.path.join(pretrained_model_path, 'config.json')
    if not os.path.isfile(config_file):
        raise RuntimeError(f"{config_file} does not exist")
    with open(config_file, "r") as f:
        config = json.load(f)

    # Extra kwargs override config entries (e.g. architecture tweaks).
    model = cls.from_config(config, **vae_additional_kwargs)
    from diffusers.utils import WEIGHTS_NAME
    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
    model_file_safetensors = model_file.replace(".bin", ".safetensors")
    if os.path.exists(model_file_safetensors):
        from safetensors.torch import load_file, safe_open
        state_dict = load_file(model_file_safetensors)
    else:
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")
        # SECURITY: torch.load unpickles arbitrary objects (no weights_only
        # here) — only load checkpoints from trusted sources.
        state_dict = torch.load(model_file, map_location="cpu")
    m, u = model.load_state_dict(state_dict, strict=False)
    print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
    print(m, u)
    return model
videox_fun/pipeline/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .pipeline_cogvideox_fun import CogVideoXFunPipeline
2
+ from .pipeline_cogvideox_fun_inpaint import CogVideoXFunInpaintPipeline
videox_fun/pipeline/pipeline_cogvideox_fun.py ADDED
@@ -0,0 +1,862 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
24
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
27
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+ from diffusers.video_processor import VideoProcessor
30
+
31
+ from ..models import (AutoencoderKLCogVideoX,
32
+ CogVideoXTransformer3DModel, T5EncoderModel,
33
+ T5Tokenizer)
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ EXAMPLE_DOC_STRING = """
39
+ Examples:
40
+ ```python
41
+ pass
42
+ ```
43
+ """
44
+
45
+
46
+ # Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed
47
def get_3d_rotary_pos_embed(
    embed_dim,
    crops_coords,
    grid_size,
    temporal_size,
    theta: int = 10000,
    use_real: bool = True,
    grid_type: str = "linspace",
    max_size: Optional[Tuple[int, int]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

    Args:
        embed_dim: (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            The grid size of the spatial positional embedding (height, width).
        temporal_size (`int`):
            The size of the temporal dimension.
        theta (`float`):
            Scaling factor for frequency computation.
        use_real (`bool`):
            Must be True; only the (cos, sin) real-valued form is supported.
        grid_type (`str`):
            Whether to use "linspace" or "slice" to compute grids.
        max_size (`Tuple[int, int]`, *optional*):
            Maximum (height, width) for "slice" grids; frequencies are computed
            once at this size and sliced down to `grid_size`.

    Returns:
        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
    """
    if use_real is not True:
        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")

    if grid_type == "linspace":
        start, stop = crops_coords
        grid_size_h, grid_size_w = grid_size
        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
        # Fix: the original assigned `grid_t = np.arange(...)` and immediately
        # overwrote it with this equivalent linspace; the dead line is removed.
        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
    elif grid_type == "slice":
        max_h, max_w = max_size
        grid_size_h, grid_size_w = grid_size
        grid_h = np.arange(max_h, dtype=np.float32)
        grid_w = np.arange(max_w, dtype=np.float32)
        grid_t = np.arange(temporal_size, dtype=np.float32)
    else:
        raise ValueError("Invalid value passed for `grid_type`.")

    # Compute dimensions for each axis: 1/4 of the head dim for time,
    # 3/8 each for height and width (sums to embed_dim).
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Temporal frequencies
    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
    # Spatial frequencies for height and width
    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)

    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
        freqs_t = freqs_t[:, None, None, :].expand(
            -1, grid_size_h, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_t
        freqs_h = freqs_h[None, :, None, :].expand(
            temporal_size, -1, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_h
        freqs_w = freqs_w[None, None, :, :].expand(
            temporal_size, grid_size_h, -1, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_w

        freqs = torch.cat(
            [freqs_t, freqs_h, freqs_w], dim=-1
        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
        freqs = freqs.view(
            temporal_size * grid_size_h * grid_size_w, -1
        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
        return freqs

    t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w

    if grid_type == "slice":
        # Frequencies were computed at max size; slice down to the actual grid.
        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]

    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin
139
+
140
+
141
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
142
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Map a source grid (h, w) onto a tgt_height x tgt_width canvas.

    Resizes `src` to fill the target while preserving aspect ratio, then
    centers it. Returns ((crop_top, crop_left), (crop_bottom, crop_right)).
    """
    src_h, src_w = src
    aspect = src_h / src_w

    if aspect > tgt_height / tgt_width:
        # Source is relatively taller: match target height, scale width down.
        resize_h = tgt_height
        resize_w = int(round(tgt_height / src_h * src_w))
    else:
        # Source is relatively wider: match target width, scale height down.
        resize_w = tgt_width
        resize_h = int(round(tgt_width / src_w * src_h))

    top = int(round((tgt_height - resize_h) / 2.0))
    left = int(round((tgt_width - resize_w) / 2.0))

    return (top, left), (top + resize_h, left + resize_w)
158
+
159
+
160
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Call the scheduler's `set_timesteps` and return the resulting schedule.

    Exactly one of `num_inference_steps`, `timesteps`, or `sigmas` drives the schedule; custom
    `timesteps`/`sigmas` are forwarded only when the scheduler's `set_timesteps` signature accepts
    them. Any extra kwargs are passed straight through to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps; must be `None` when `timesteps` is given.
        device (`str` or `torch.device`, *optional*):
            Device to place the timesteps on; `None` leaves them where the scheduler puts them.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule; mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule; mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule and the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        # Only schedulers whose `set_timesteps` takes a `timesteps` kwarg support custom schedules.
        if "timesteps" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        # Same capability check for sigma-based schedules.
        if "sigmas" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedules define the step count implicitly via their length.
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
218
+
219
+
220
@dataclass
class CogVideoXFunPipelineOutput(BaseOutput):
    r"""
    Output class for CogVideoX-Fun pipelines.

    Args:
        videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Generated video(s); see the docstring above for the accepted shapes/types.
    videos: torch.Tensor
233
+
234
+
235
class CogVideoXFunPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using CogVideoX_Fun.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. CogVideoX uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CogVideoXTransformer3DModel`]):
            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
    """

    _optional_components = []
    # Order in which modules are moved to GPU when sequential CPU offload is enabled.
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    # Tensors that `callback_on_step_end` callbacks are allowed to read/replace.
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        # Spatial downscale factor of the VAE (falls back to 8 if no VAE is registered).
        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        # Temporal downscale factor of the VAE (falls back to 4 if no VAE is registered).
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
        )

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Tokenize `prompt` and run it through the T5 text encoder, duplicating embeddings per video."""
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Warn the user when the prompt was silently truncated to `max_sequence_length`.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        # Negative embeddings are only needed when CFG is active and none were supplied.
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds

    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
    ):
        """Sample (or move) the initial latent noise of shape [B, F', C, H/8, W/8] and scale it for the scheduler."""
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # Note the frames-first layout: (batch, latent_frames, channels, latent_height, latent_width).
        shape = (
            batch_size,
            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents with the VAE and return frames as a float32 NumPy array in [0, 1]."""
        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
        latents = 1 / self.vae.config.scaling_factor * latents

        frames = self.vae.decode(latents).sample
        frames = (frames / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
        frames = frames.cpu().float().numpy()
        return frames

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate `__call__` arguments; raises ValueError/TypeError on inconsistent input combinations."""
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def fuse_qkv_projections(self) -> None:
        r"""Enables fused QKV projections."""
        self.fusing_transformer = True
        self.transformer.fuse_qkv_projections()

    def unfuse_qkv_projections(self) -> None:
        r"""Disable QKV projection fusion if enabled."""
        # NOTE(review): `fusing_transformer` is only assigned in `fuse_qkv_projections`; calling this
        # method first would raise AttributeError — confirm callers always fuse before unfusing.
        if not self.fusing_transformer:
            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
        else:
            self.transformer.unfuse_qkv_projections()
            self.fusing_transformer = False

    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build 3D RoPE cos/sin tables sized to the latent grid; layout differs between CogVideoX 1.0 and 1.5."""
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        p_t = self.transformer.config.patch_size_t

        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p

        if p_t is None:
            # CogVideoX 1.0
            grid_crops_coords = get_resize_crop_region_for_grid(
                (grid_height, grid_width), base_size_width, base_size_height
            )
            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=grid_crops_coords,
                grid_size=(grid_height, grid_width),
                temporal_size=num_frames,
            )
        else:
            # CogVideoX 1.5
            # Temporal patching: round the frame count up to a whole number of temporal patches.
            base_num_frames = (num_frames + p_t - 1) // p_t

            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=None,
                grid_size=(grid_height, grid_width),
                temporal_size=base_num_frames,
                grid_type="slice",
                max_size=(base_size_height, base_size_width),
            )

        freqs_cos = freqs_cos.to(device=device)
        freqs_sin = freqs_sin.to(device=device)
        return freqs_cos, freqs_sin

    @property
    def guidance_scale(self):
        # Current CFG scale; may be re-assigned per-step when `use_dynamic_cfg` is enabled in `__call__`.
        return self._guidance_scale

    @property
    def num_timesteps(self):
        # Number of timesteps of the most recent schedule prepared in `__call__`.
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the transformer's attention processors (set in `__call__`).
        return self._attention_kwargs

    @property
    def interrupt(self):
        # When True, remaining denoising steps are skipped (see the loop in `__call__`).
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        num_frames: int = 49,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        use_dynamic_cfg: bool = False,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: str = "numpy",
        return_dict: bool = False,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 226,
    ) -> Union[CogVideoXFunPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_frames (`int`, defaults to `48`):
                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
                num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
                needs to be satisfied is that of divisibility mentioned above.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will ge generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `226`):
                Maximum sequence length in encoded prompt. Must be consistent with
                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.

        Examples:

        Returns:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
        num_frames = num_frames or self.transformer.config.sample_frames

        # NOTE(review): the `num_videos_per_prompt` argument is overridden to 1 here, so the
        # parameter is effectively ignored — confirm whether multi-video generation is intended.
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if do_classifier_free_guidance:
            # Concatenate [uncond, cond] so both branches run in one transformer forward pass.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents
        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
        patch_size_t = self.transformer.config.patch_size_t
        additional_frames = 0
        if num_frames != 1 and patch_size_t is not None and latent_frames % patch_size_t != 0:
            additional_frames = patch_size_t - latent_frames % patch_size_t
            num_frames += additional_frames * self.vae_scale_factor_temporal

        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Create rotary embeds if required
        image_rotary_emb = (
            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
            old_pred_original_sample = None
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # Double the batch for CFG so uncond/cond predictions come from one forward pass.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    image_rotary_emb=image_rotary_emb,
                    return_dict=False,
                )[0]
                noise_pred = noise_pred.float()

                # perform guidance
                if use_dynamic_cfg:
                    # Cosine ramp of the CFG scale over the schedule (stronger guidance later in sampling).
                    self._guidance_scale = 1 + guidance_scale * (
                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                    )
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                else:
                    # The DPM scheduler additionally threads the previous x0 prediction through the steps.
                    latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        latents,
                        **extra_step_kwargs,
                        return_dict=False,
                    )
                latents = latents.to(prompt_embeds.dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # Callbacks may replace these tensors mid-sampling.
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if output_type == "numpy":
            video = self.decode_latents(latents)
        elif not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        # NOTE(review): with the default `output_type="numpy"` and `return_dict=False`, the NumPy
        # array is converted to a tensor; `torch.from_numpy` would fail for `output_type="latent"`
        # (already a tensor) — confirm callers only combine `return_dict=False` with numpy output.
        if not return_dict:
            video = torch.from_numpy(video)

        return CogVideoXFunPipelineOutput(videos=video)
videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py ADDED
@@ -0,0 +1,1244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from diffusers.image_processor import VaeImageProcessor
26
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
29
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
30
+ from diffusers.utils.torch_utils import randn_tensor
31
+ from diffusers.video_processor import VideoProcessor
32
+ from einops import rearrange
33
+
34
+ from ..models import (AutoencoderKLCogVideoX,
35
+ CogVideoXTransformer3DModel, T5EncoderModel,
36
+ T5Tokenizer)
37
+
38
# Module-level logger, following the diffusers one-logger-per-module convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Placeholder usage example injected into `__call__`'s docstring via
# `@replace_example_docstring`; no runnable example is provided yet.
EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        pass
        ```
"""
47
+
48
# Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed
def get_3d_rotary_pos_embed(
    embed_dim,
    crops_coords,
    grid_size,
    temporal_size,
    theta: int = 10000,
    use_real: bool = True,
    grid_type: str = "linspace",
    max_size: Optional[Tuple[int, int]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

    Args:
        embed_dim: (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop. Unused when `grid_type == "slice"`.
        grid_size (`Tuple[int]`):
            The grid size of the spatial positional embedding (height, width).
        temporal_size (`int`):
            The size of the temporal dimension.
        theta (`float`):
            Scaling factor for frequency computation.
        use_real (`bool`):
            Must be `True`; only real-valued (cos, sin) outputs are supported.
        grid_type (`str`):
            Whether to use "linspace" or "slice" to compute grids.
        max_size (`Tuple[int, int]`, *optional*):
            Maximum (height, width) grid; required when `grid_type == "slice"`.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: (cos, sin) tables, each with shape
        `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
    """
    if use_real is not True:
        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")

    if grid_type == "linspace":
        start, stop = crops_coords
        grid_size_h, grid_size_w = grid_size
        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
        # Fix: removed a dead `grid_t = np.arange(...)` assignment that was
        # immediately overwritten by this equivalent linspace.
        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
    elif grid_type == "slice":
        max_h, max_w = max_size
        grid_size_h, grid_size_w = grid_size
        grid_h = np.arange(max_h, dtype=np.float32)
        grid_w = np.arange(max_w, dtype=np.float32)
        grid_t = np.arange(temporal_size, dtype=np.float32)
    else:
        raise ValueError("Invalid value passed for `grid_type`.")

    # Split the head dimension between axes: 1/4 temporal, 3/8 height, 3/8 width.
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Fix: forward `theta` so the parameter actually takes effect (its default
    # matches get_1d_rotary_pos_embed's default, so existing callers see
    # identical results).
    # Temporal frequencies
    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
    # Spatial frequencies for height and width
    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)

    # Broadcast and concatenate temporal and spatial frequencies into a 3D tensor.
    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
        freqs_t = freqs_t[:, None, None, :].expand(
            -1, grid_size_h, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_t
        freqs_h = freqs_h[None, :, None, :].expand(
            temporal_size, -1, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_h
        freqs_w = freqs_w[None, None, :, :].expand(
            temporal_size, grid_size_h, -1, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_w

        freqs = torch.cat(
            [freqs_t, freqs_h, freqs_w], dim=-1
        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
        freqs = freqs.view(
            temporal_size * grid_size_h * grid_size_w, -1
        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
        return freqs

    t_cos, t_sin = freqs_t  # both have shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both have shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both have shape: grid_size_w, dim_w

    if grid_type == "slice":
        # Tables were built for the max grid; slice down to the requested size.
        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]

    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin
141
+
142
+
143
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """
    Compute the centered crop region obtained by aspect-preserving resize of a
    `src = (height, width)` grid into a `tgt_height x tgt_width` canvas.

    Returns `((top, left), (bottom, right))` of the resized region within the
    target canvas.
    """
    src_h, src_w = src

    # Scale so the source exactly covers one target dimension while keeping
    # its aspect ratio.
    if src_h / src_w > tgt_height / tgt_width:
        new_h = tgt_height
        new_w = int(round(tgt_height / src_h * src_w))
    else:
        new_w = tgt_width
        new_h = int(round(tgt_width / src_w * src_h))

    # Center the resized region inside the target canvas.
    top = int(round((tgt_height - new_h) / 2.0))
    left = int(round((tgt_width - new_w) / 2.0))

    return (top, left), (top + new_h, left + new_w)
160
+
161
+
162
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Call `scheduler.set_timesteps` and return the resulting timestep schedule.

    The schedule is selected by exactly one of `num_inference_steps`,
    `timesteps` or `sigmas`; passing both `timesteps` and `sigmas` is an error.
    Any extra kwargs are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps; must be `None` when `timesteps` is given.
        device (`str` or `torch.device`, *optional*):
            Device to move the timesteps to; `None` leaves them in place.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule; requires scheduler support.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule; requires scheduler support.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule and the number of
        inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        # Only schedulers whose set_timesteps accepts `timesteps` support this.
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        schedule = scheduler.timesteps
        return schedule, len(schedule)

    if sigmas is not None:
        # Only schedulers whose set_timesteps accepts `sigmas` support this.
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        schedule = scheduler.timesteps
        return schedule, len(schedule)

    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    return scheduler.timesteps, num_inference_steps
220
+
221
+
222
def resize_mask(mask, latent, process_first_frame_only=True):
    """
    Trilinearly resize a [B, C, F, H, W] mask to the (frames, height, width)
    of `latent`.

    When `process_first_frame_only` is True, the first input frame is mapped to
    the first latent frame on its own and the remaining frames are resized to
    fill the rest; otherwise the whole clip is resized in one pass.
    """
    target_frames, target_h, target_w = list(latent.size()[2:])

    if not process_first_frame_only:
        return F.interpolate(
            mask,
            size=[target_frames, target_h, target_w],
            mode='trilinear',
            align_corners=False,
        )

    # First frame -> exactly one latent frame.
    head = F.interpolate(
        mask[:, :, 0:1, :, :],
        size=[1, target_h, target_w],
        mode='trilinear',
        align_corners=False,
    )

    remaining = target_frames - 1
    if remaining == 0:
        return head

    # Remaining frames -> the rest of the latent frames.
    tail = F.interpolate(
        mask[:, :, 1:, :, :],
        size=[remaining, target_h, target_w],
        mode='trilinear',
        align_corners=False,
    )
    return torch.cat([head, tail], dim=2)
257
+
258
+
259
def add_noise_to_reference_video(image, ratio=None):
    """
    Add Gaussian noise to a reference video tensor [B, C, F, H, W].

    The per-sample noise scale is `ratio` when given, otherwise drawn from a
    log-normal distribution (log-sigma ~ N(-3, 0.5)). Positions where the
    input equals -1 (padding) receive no noise.
    """
    batch = image.shape[0]
    if ratio is None:
        # Sample one log-normal sigma per batch element.
        log_sigma = torch.normal(mean=-3.0, std=0.5, size=(batch,)).to(image.device)
        sigma = torch.exp(log_sigma).to(image.dtype)
    else:
        sigma = torch.ones((batch,)).to(image.device, image.dtype) * ratio

    noise = torch.randn_like(image) * sigma.view(-1, 1, 1, 1, 1)
    # Do not perturb -1 (padding) pixels.
    noise = torch.where(image == -1, torch.zeros_like(image), noise)
    return image + noise
270
+
271
+
272
@dataclass
class CogVideoXFunPipelineOutput(BaseOutput):
    r"""
    Output class for CogVideo pipelines.

    Args:
        videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Decoded videos (or raw latents when output_type == "latent").
    videos: torch.Tensor
285
+
286
+
287
+ class CogVideoXFunInpaintPipeline(DiffusionPipeline):
288
+ r"""
289
+ Pipeline for text-to-video generation using CogVideoX.
290
+
291
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
292
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
293
+
294
+ Args:
295
+ vae ([`AutoencoderKL`]):
296
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
297
+ text_encoder ([`T5EncoderModel`]):
298
+ Frozen text-encoder. CogVideoX_Fun uses
299
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
300
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
301
+ tokenizer (`T5Tokenizer`):
302
+ Tokenizer of class
303
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
304
+ transformer ([`CogVideoXTransformer3DModel`]):
305
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
306
+ scheduler ([`SchedulerMixin`]):
307
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
308
+ """
309
+
310
+ _optional_components = []
311
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
312
+
313
+ _callback_tensor_inputs = [
314
+ "latents",
315
+ "prompt_embeds",
316
+ "negative_prompt_embeds",
317
+ ]
318
+
319
    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        """Register the sub-models and derive VAE scale factors and frame/mask processors."""
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        # Spatial downsampling factor of the VAE: 2 ** (number of downsample stages).
        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        # Temporal compression ratio of the video VAE (pixel frames per latent frame).
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
        )

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

        # Separate processors for input frames (normalized) and inpainting masks
        # (grayscale, unnormalized, not binarized).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=False, do_convert_grayscale=True
        )
346
+
347
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Tokenize `prompt` and encode it with the T5 text encoder.

        Returns embeddings of shape
        `[batch_size * num_videos_per_prompt, seq_len, hidden]`, padded or
        truncated to `max_sequence_length` tokens.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Re-tokenize without truncation solely to detect and report dropped text.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds
388
+
389
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype

        Returns:
            `(prompt_embeds, negative_prompt_embeds)`; the negative embeddings are
            only computed when `do_classifier_free_guidance` is True and none were
            passed in.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            # No raw prompt: infer the batch size from the precomputed embeddings.
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # Empty string is the default unconditional prompt for CFG.
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds
469
+
470
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        video_length,
        dtype,
        device,
        generator,
        latents=None,
        video=None,
        timestep=None,
        is_strength_max=True,
        return_noise=False,
        return_video_latents=False,
    ):
        """
        Build the initial latents for the denoising loop.

        Returns a tuple `(latents,)`, extended with the raw noise when
        `return_noise` is True and with the VAE-encoded video latents when
        `return_video_latents` is True.
        """
        # Latent layout is [B, F_latent, C, H_latent, W_latent].
        shape = (
            batch_size,
            (video_length - 1) // self.vae_scale_factor_temporal + 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # The reference video must be encoded when its latents are requested, or
        # when strength < 1 and no latents were supplied (they seed the noise).
        if return_video_latents or (latents is None and not is_strength_max):
            # assumes `video` is a [B, C, F, H, W] pixel tensor — TODO confirm against callers
            video = video.to(device=device, dtype=self.vae.dtype)

            # Encode one sample at a time, presumably to bound peak VAE memory.
            bs = 1
            new_video = []
            for i in range(0, video.shape[0], bs):
                video_bs = video[i : i + bs]
                video_bs = self.vae.encode(video_bs)[0]
                video_bs = video_bs.sample()
                new_video.append(video_bs)
            video = torch.cat(new_video, dim = 0)
            video = video * self.vae.config.scaling_factor

            # Tile the encoded video across the requested batch size.
            video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
            video_latents = video_latents.to(device=device, dtype=dtype)
            # VAE outputs [B, C, F, H, W]; the transformer expects [B, F, C, H, W].
            video_latents = rearrange(video_latents, "b c f h w -> b f c h w")

        if latents is None:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            # if strength is 1. then initialise the latents to noise, else initial to image + noise
            latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
            # if pure noise then scale the initial latents by the Scheduler's init sigma
            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
        else:
            # NOTE(review): user-supplied latents are scaled by init_noise_sigma
            # unconditionally — confirm callers pass unscaled noise here.
            noise = latents.to(device)
            latents = noise * self.scheduler.init_noise_sigma

        # scale the initial noise by the standard deviation required by the scheduler
        outputs = (latents,)

        if return_noise:
            outputs += (noise,)

        if return_video_latents:
            outputs += (video_latents,)

        return outputs
537
+
538
    def prepare_mask_latents(
        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
    ):
        """
        Encode the inpainting mask and the masked video into VAE latent space.

        Either input may be None, in which case the corresponding output is None.
        Returns `(mask, masked_image_latents)`; both are scaled by the VAE's
        scaling factor.
        """
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision

        if mask is not None:
            mask = mask.to(device=device, dtype=self.vae.dtype)
            # Encode one sample at a time, presumably to bound peak VAE memory.
            bs = 1
            new_mask = []
            for i in range(0, mask.shape[0], bs):
                mask_bs = mask[i : i + bs]
                mask_bs = self.vae.encode(mask_bs)[0]
                # `.mode()` takes the deterministic posterior mode (no sampling).
                mask_bs = mask_bs.mode()
                new_mask.append(mask_bs)
            mask = torch.cat(new_mask, dim = 0)
            mask = mask * self.vae.config.scaling_factor

        if masked_image is not None:
            # Optionally perturb the reference frames, matching the noise
            # augmentation the inpaint model was trained with.
            if self.transformer.config.add_noise_in_inpaint_model:
                masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
            masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
            bs = 1
            new_mask_pixel_values = []
            for i in range(0, masked_image.shape[0], bs):
                mask_pixel_values_bs = masked_image[i : i + bs]
                mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
                mask_pixel_values_bs = mask_pixel_values_bs.mode()
                new_mask_pixel_values.append(mask_pixel_values_bs)
            masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
            masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
        else:
            masked_image_latents = None

        return mask, masked_image_latents
574
+
575
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
576
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
577
+ latents = 1 / self.vae.config.scaling_factor * latents
578
+
579
+ frames = self.vae.decode(latents).sample
580
+ frames = (frames / 2 + 0.5).clamp(0, 1)
581
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
582
+ frames = frames.cpu().float().numpy()
583
+ return frames
584
+
585
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
586
+ def prepare_extra_step_kwargs(self, generator, eta):
587
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
588
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
589
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
590
+ # and should be between [0, 1]
591
+
592
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
593
+ extra_step_kwargs = {}
594
+ if accepts_eta:
595
+ extra_step_kwargs["eta"] = eta
596
+
597
+ # check if the scheduler accepts generator
598
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
599
+ if accepts_generator:
600
+ extra_step_kwargs["generator"] = generator
601
+ return extra_step_kwargs
602
+
603
    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate call arguments, raising ValueError on any inconsistent combination."""
        # Spatial dimensions must be divisible by the VAE downsampling factor.
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        # Exactly one of `prompt` / `prompt_embeds` must be supplied.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # Raw prompts may not be mixed with precomputed negative embeddings.
        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Positive and negative embeddings must be shape-compatible for CFG.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
654
+
655
    def fuse_qkv_projections(self) -> None:
        r"""Enables fused QKV projections."""
        # Record fusion state so `unfuse_qkv_projections` can detect a no-op call.
        self.fusing_transformer = True
        self.transformer.fuse_qkv_projections()
659
+
660
+ def unfuse_qkv_projections(self) -> None:
661
+ r"""Disable QKV projection fusion if enabled."""
662
+ if not self.fusing_transformer:
663
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
664
+ else:
665
+ self.transformer.unfuse_qkv_projections()
666
+ self.fusing_transformer = False
667
+
668
    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build the 3D RoPE (cos, sin) tables for a video of the given pixel size.

        `num_frames` is presumably in latent frames — TODO confirm against callers.
        """
        # Token-grid dimensions after VAE downsampling and spatial patchification.
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        p_t = self.transformer.config.patch_size_t

        # Grid size the transformer was trained at (its "base" resolution).
        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p

        if p_t is None:
            # CogVideoX 1.0: no temporal patching; crop-resize onto the base grid.
            grid_crops_coords = get_resize_crop_region_for_grid(
                (grid_height, grid_width), base_size_width, base_size_height
            )
            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=grid_crops_coords,
                grid_size=(grid_height, grid_width),
                temporal_size=num_frames,
            )
        else:
            # CogVideoX 1.5: temporal patching; slice frequencies out of the
            # maximal base grid instead of crop-resizing.
            base_num_frames = (num_frames + p_t - 1) // p_t

            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=None,
                grid_size=(grid_height, grid_width),
                temporal_size=base_num_frames,
                grid_type="slice",
                max_size=(base_size_height, base_size_width),
            )

        freqs_cos = freqs_cos.to(device=device)
        freqs_sin = freqs_sin.to(device=device)
        return freqs_cos, freqs_sin
711
+
712
    @property
    def guidance_scale(self):
        # Backing attribute `_guidance_scale` — presumably set in `__call__`; confirm.
        return self._guidance_scale

    @property
    def num_timesteps(self):
        # Backing attribute `_num_timesteps` — presumably set in `__call__`; confirm.
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        # Backing attribute `_attention_kwargs` — presumably set in `__call__`; confirm.
        return self._attention_kwargs

    @property
    def interrupt(self):
        # Backing attribute `_interrupt` — presumably a cooperative-cancel flag; confirm.
        return self._interrupt
727
+
728
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
729
+ def get_timesteps(self, num_inference_steps, strength, device):
730
+ # get the original timestep using init_timestep
731
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
732
+
733
+ t_start = max(num_inference_steps - init_timestep, 0)
734
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
735
+
736
+ return timesteps, num_inference_steps - t_start
737
+
738
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    height: int = 480,
    width: int = 720,
    video: Union[torch.FloatTensor] = None,
    mask_video: Union[torch.FloatTensor] = None,
    masked_video_latents: Union[torch.FloatTensor] = None,
    num_frames: int = 49,
    num_inference_steps: int = 50,
    timesteps: Optional[List[int]] = None,
    guidance_scale: float = 6,
    use_dynamic_cfg: bool = False,
    num_videos_per_prompt: int = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: str = "numpy",
    return_dict: bool = False,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 226,
    strength: float = 1,
    noise_aug_strength: float = 0.0563,
    comfyui_progressbar: bool = False,
    temporal_multidiffusion_stride: int = 16,
    use_trimask: bool = False,
    zero_out_mask_region: bool = False,
    binarize_mask: bool = False,
    skip_unet: bool = False,
    use_vae_mask: bool = False,
    stack_mask: bool = False,
) -> Union[CogVideoXFunPipelineOutput, Tuple]:
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image. This is set to 1024 by default for the best results.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image. This is set to 1024 by default for the best results.
        num_frames (`int`, defaults to `48`):
            Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
            contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
            num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
            needs to be satisfied is that of divisibility mentioned above.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 7.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will ge generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
            of a plain tuple.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int`, defaults to `226`):
            Maximum sequence length in encoded prompt. Must be consistent with
            `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
        strength (`float`, defaults to `1`):
            Img2img-style denoising strength; `1.0` starts from pure noise, lower values keep more of `video`.
        noise_aug_strength (`float`, defaults to `0.0563`):
            Noise-augmentation strength forwarded to `prepare_mask_latents`.
        temporal_multidiffusion_stride (`int`, defaults to `16`):
            Stride (in pixel frames) between overlapping temporal windows when the latent sequence is longer than
            one window; divided by 4 to obtain the latent-frame stride.
        use_trimask (`bool`): Quantize the mask into three levels (0, 127/255, 1) instead of binarizing at 0.5.
        zero_out_mask_region (`bool`): Replace strongly-masked pixels of the conditioning video with -1.
        binarize_mask (`bool`): Re-quantize the resized latent mask after downsampling.
        skip_unet (`bool`): Debug path — decode `masked_video_latents` directly, bypassing the denoising loop.
        use_vae_mask (`bool`): Encode the mask through the VAE (`prepare_mask_latents`) instead of resizing it.
        stack_mask (`bool`): Pack pixel-frame masks 4-per-latent-frame along channels before resizing.

    Examples:

    Returns:
        [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
        [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images.
    """

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
    width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
    num_frames = num_frames or self.transformer.config.sample_frames

    # The argument of the same name is deliberately overridden: this pipeline
    # only supports a single video per prompt.
    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds,
        negative_prompt_embeds,
    )
    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._interrupt = False

    # 2. Default call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0
    logger.info(f'Use cfg: {do_classifier_free_guidance}, guidance_scale={guidance_scale}')

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        negative_prompt,
        do_classifier_free_guidance,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
    )
    if do_classifier_free_guidance:
        # Batch [uncond; cond] so one forward pass serves both CFG branches.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

    # 4. set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(
        num_inference_steps=num_inference_steps, strength=strength, device=device
    )
    self._num_timesteps = len(timesteps)
    if comfyui_progressbar:
        from comfy.utils import ProgressBar
        pbar = ProgressBar(num_inference_steps + 2)
    # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
    latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
    # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
    is_strength_max = strength == 1.0

    # 5. Prepare latents.
    if video is not None:
        video_length = video.shape[2]
        init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
        init_video = init_video.to(dtype=torch.float32)
        init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
    else:
        video_length = num_frames
        init_video = None

    # Magvae needs the number of frames to be 4n + 1.
    local_latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1
    # For CogVideoX 1.5, the latent frames should be clipped to make it divisible by patch_size_t
    patch_size_t = self.transformer.config.patch_size_t
    additional_frames = 0
    if patch_size_t is not None and local_latent_length % patch_size_t != 0:
        additional_frames = local_latent_length % patch_size_t
        num_frames -= additional_frames * self.vae_scale_factor_temporal
    if num_frames <= 0:
        num_frames = 1

    num_channels_latents = self.vae.config.latent_channels
    num_channels_transformer = self.transformer.config.in_channels
    # When the transformer's input width equals the VAE latent width there is no
    # room for concatenated inpaint channels, so the image latents are returned instead.
    return_image_latents = num_channels_transformer == num_channels_latents

    latents_outputs = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        height,
        width,
        video_length,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
        video=init_video,
        timestep=latent_timestep,
        is_strength_max=is_strength_max,
        return_noise=True,
        return_video_latents=return_image_latents,
    )
    if return_image_latents:
        latents, noise, image_latents = latents_outputs
    else:
        latents, noise = latents_outputs
    if comfyui_progressbar:
        pbar.update(1)

    # Build `inpaint_latents` (extra conditioning channels) from the mask video.
    if mask_video is not None:
        if (mask_video == 255).all():
            # Fully-masked video: condition on zeros (pure generation).
            mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype)
            masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)

            mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
            masked_video_latents_input = (
                torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
            )
            inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
        else:
            # Prepare mask latent variables
            video_length = video.shape[2]
            mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
            if use_trimask:
                # Quantize to three levels: keep (1), uncertain (127/255), discard (0).
                mask_condition = torch.where(mask_condition > 0.75, 1., mask_condition)
                mask_condition = torch.where((mask_condition <= 0.75) * (mask_condition >= 0.25), 127. / 255., mask_condition)
                mask_condition = torch.where(mask_condition < 0.25, 0., mask_condition)
            else:
                mask_condition = torch.where(mask_condition > 0.5, 1., 0.)

            mask_condition = mask_condition.to(dtype=torch.float32)
            mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)

            if num_channels_transformer != num_channels_latents:
                mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
                if masked_video_latents is None:
                    if zero_out_mask_region:
                        masked_video = init_video * (mask_condition_tile < 0.75) + torch.ones_like(init_video) * (mask_condition_tile > 0.75) * -1
                    else:
                        masked_video = init_video
                else:
                    masked_video = masked_video_latents

                mask_encoded, masked_video_latents = self.prepare_mask_latents(
                    1 - mask_condition_tile if use_vae_mask else None,
                    masked_video,
                    batch_size,
                    height,
                    width,
                    prompt_embeds.dtype,
                    device,
                    generator,
                    do_classifier_free_guidance,
                    noise_aug_strength=noise_aug_strength,
                )
                if not use_vae_mask and not stack_mask:
                    # Default path: interpolate the pixel mask down to latent resolution.
                    mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
                    if binarize_mask:
                        if use_trimask:
                            mask_latents = torch.where(mask_latents > 0.75, 1., mask_latents)
                            mask_latents = torch.where((mask_latents <= 0.75) * (mask_latents >= 0.25), 0.5, mask_latents)
                            mask_latents = torch.where(mask_latents < 0.25, 0., mask_latents)
                        else:
                            mask_latents = torch.where(mask_latents < 0.9, 0., 1.).to(mask_latents.dtype)

                    mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor

                    mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
                    mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)

                    mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
                    mask = rearrange(mask, "b c f h w -> b f c h w")
                elif stack_mask:
                    # Pack 4 pixel frames per latent frame along a new channel axis
                    # (first pixel frame repeated 4x, matching the 4x temporal VAE factor).
                    mask_latents = torch.cat([
                        torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
                        mask_condition[:, :, 1:],
                    ], dim=2)
                    mask_latents = mask_latents.view(
                        mask_latents.shape[0],
                        mask_latents.shape[2] // 4,
                        4,
                        mask_latents.shape[3],
                        mask_latents.shape[4],
                    )
                    mask_latents = mask_latents.transpose(1, 2)
                    mask_latents = resize_mask(1 - mask_latents, masked_video_latents).to(latents.device, latents.dtype)
                    mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
                else:
                    # use_vae_mask path: the mask was encoded by prepare_mask_latents above.
                    mask_input = (
                        torch.cat([mask_encoded] * 2) if do_classifier_free_guidance else mask_encoded
                    )

                masked_video_latents_input = (
                    torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
                )

                mask_input = rearrange(mask_input, "b c f h w -> b f c h w")
                masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w")

                # concat(binary mask, encode(mask * video))
                inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
            else:
                mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
                mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
                mask = rearrange(mask, "b c f h w -> b f c h w")

                inpaint_latents = None
    else:
        if num_channels_transformer != num_channels_latents:
            # No mask video: feed all-zero inpaint conditioning.
            mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
            masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)

            mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
            masked_video_latents_input = (
                torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
            )
            inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
        else:
            mask = torch.zeros_like(init_video[:, :1])
            mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
            mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
            mask = rearrange(mask, "b c f h w -> b f c h w")

            inpaint_latents = None
    if comfyui_progressbar:
        pbar.update(1)

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
    # NOTE(review): `mask_condition` is only defined on the `mask_video is not None`
    # and not-all-255 branch; this debug line raises NameError on the other paths
    # unless debug logging is the only consumer — confirm intended call patterns.
    logger.debug(f'Pipeline mask {mask_condition.shape} {mask_condition.dtype} {mask_condition.min()} {mask_condition.max()}')
    # 8. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    # Window size in latent frames; the constant 4 matches the temporal VAE factor
    # assumed by the multidiffusion stride below.
    latent_temporal_window_size = (num_frames - 1) // 4 + 1
    if latents.size(1) > latent_temporal_window_size:
        logger.info(f'Adopt temporal multidiffusion for the latents {latents.shape} {latents.dtype}')

    # VAE experiment
    if skip_unet:
        # Debug shortcut: decode the masked-video latents directly and return,
        # bypassing the transformer entirely.
        masked_video_latents = rearrange(masked_video_latents, "b c f h w -> b f c h w")
        if output_type == "numpy":
            video = self.decode_latents(masked_video_latents)
        elif not output_type == "latent":
            video = self.decode_latents(masked_video_latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = masked_video_latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            video = torch.from_numpy(video)

        return CogVideoXFunPipelineOutput(videos=video)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        # for DPM-solver++
        old_pred_original_sample = None
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # One denoising step over a (possibly windowed) latent slice; closes over
            # `t`, `i`, `prompt_embeds`, and the scheduler state.
            def _sample(_latents, _inpaint_latents):
                # 7. Create rotary embeds if required
                image_rotary_emb = (
                    self._prepare_rotary_positional_embeddings(height, width, _latents.size(1), device)
                    if self.transformer.config.use_rotary_positional_embeddings
                    else None
                )

                latent_model_input = torch.cat([_latents] * 2) if do_classifier_free_guidance else _latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    image_rotary_emb=image_rotary_emb,
                    return_dict=False,
                    inpaint_latents=_inpaint_latents,
                )[0]
                noise_pred = noise_pred.float()

                # perform guidance
                if use_dynamic_cfg:
                    self._guidance_scale = 1 + guidance_scale * (
                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                    )
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    _latents = self.scheduler.step(noise_pred, t, _latents, **extra_step_kwargs, return_dict=False)[0]
                else:
                    _latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        _latents,
                        **extra_step_kwargs,
                        return_dict=False,
                    )
                _latents = _latents.to(prompt_embeds.dtype)
                return _latents

            if latents.size(1) <= latent_temporal_window_size:
                latents = _sample(latents, inpaint_latents)
            else:
                # adopt temporal multidiffusion: denoise overlapping temporal windows
                # and blend them with linear cross-fade weights at the overlaps.
                latents_canvas = torch.zeros_like(latents).float()
                weights_canvas = torch.zeros(1, latents.size(1), 1, 1, 1).to(latents.device).float()
                temporal_stride = temporal_multidiffusion_stride // 4
                assert latent_temporal_window_size > temporal_stride

                time_beg = 0
                while time_beg < latents.size(1):
                    time_end = min(time_beg + latent_temporal_window_size, latents.size(1))

                    latents_i = latents[:, time_beg:time_end]
                    if inpaint_latents is not None:
                        inpaint_latents_i = inpaint_latents[:, time_beg:time_end]
                    else:
                        inpaint_latents_i = None

                    latents_i = _sample(latents_i, inpaint_latents_i)

                    weights_i = torch.ones(1, time_end - time_beg, 1, 1, 1).to(latents.device).to(latents.dtype)
                    if time_beg > 0 and temporal_stride > 0:
                        # Ramp-in over the leading overlap with the previous window.
                        weights_i[:, :temporal_stride] = (torch.linspace(0., 1., temporal_stride + 2)[1:-1]
                                                          .to(latents.device)
                                                          .to(latents.dtype)
                                                          .reshape(1, temporal_stride, 1, 1, 1))
                    if time_end < latents.size(1) and temporal_stride > 0:
                        # Ramp-out over the trailing overlap with the next window.
                        weights_i[:, -temporal_stride:] = (torch.linspace(1., 0., temporal_stride + 2)[1:-1]
                                                           .to(latents.device)
                                                           .to(latents.dtype)
                                                           .reshape(1, temporal_stride, 1, 1, 1))

                    latents_canvas[:, time_beg:time_end] += latents_i * weights_i
                    weights_canvas[:, time_beg:time_end] += weights_i

                    time_beg = time_end - temporal_stride
                    if time_end >= latents.size(1):
                        break
                latents = (latents_canvas / weights_canvas).to(latents.dtype)

            # call the callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if comfyui_progressbar:
                    pbar.update(1)

    if output_type == "numpy":
        video = self.decode_latents(latents)
    elif not output_type == "latent":
        video = self.decode_latents(latents)
        video = self.video_processor.postprocess_video(video=video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    self.maybe_free_model_hooks()

    # NOTE(review): regardless of `return_dict`, an output dataclass is returned;
    # `return_dict=False` only wraps the numpy result back into a tensor first.
    if not return_dict:
        video = torch.from_numpy(video)

    return CogVideoXFunPipelineOutput(videos=video)
videox_fun/pipeline/pipeline_wan_fun.py ADDED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import FlowMatchEulerDiscreteScheduler
9
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
10
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
11
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
12
+ from diffusers.utils.torch_utils import randn_tensor
13
+ from diffusers.video_processor import VideoProcessor
14
+
15
+ from ..models import (AutoencoderKLWan, AutoTokenizer,
16
+ WanT5EncoderModel, WanTransformer3DModel)
17
+
18
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
+
20
+
21
+ EXAMPLE_DOC_STRING = """
22
+ Examples:
23
+ ```python
24
+ pass
25
+ ```
26
+ """
27
+
28
+
29
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Plain path: no custom schedule, defer entirely to the scheduler.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # A custom schedule was requested — check the scheduler can accept it.
    accepted_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # With a custom schedule, the effective step count is whatever the scheduler produced.
    final_schedule = scheduler.timesteps
    return final_schedule, len(final_schedule)
87
+
88
+
89
@dataclass
class WanPipelineOutput(BaseOutput):
    r"""
    Output class for Wan video generation pipelines.

    Args:
        videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Generated videos; concrete type depends on the pipeline's `output_type`.
    videos: torch.Tensor
102
+
103
+
104
+ class WanFunPipeline(DiffusionPipeline):
105
+ r"""
106
+ Pipeline for text-to-video generation using Wan.
107
+
108
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
109
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
110
+ """
111
+
112
+ _optional_components = []
113
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
114
+
115
+ _callback_tensor_inputs = [
116
+ "latents",
117
+ "prompt_embeds",
118
+ "negative_prompt_embeds",
119
+ ]
120
+
121
def __init__(
    self,
    tokenizer: AutoTokenizer,
    text_encoder: WanT5EncoderModel,
    vae: AutoencoderKLWan,
    transformer: WanTransformer3DModel,
    scheduler: FlowMatchEulerDiscreteScheduler,
):
    """Register the Wan submodules with the pipeline and build the video post-processor.

    Args:
        tokenizer: Tokenizer paired with the T5 text encoder.
        text_encoder: Wan T5 encoder producing the prompt embeddings.
        vae: Wan video autoencoder used for latent encode/decode.
        transformer: Wan 3D diffusion transformer (the denoiser).
        scheduler: Flow-matching Euler scheduler driving the sampling loop.
    """
    super().__init__()

    self.register_modules(
        tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
    )
    # `spacial_compression_ratio` (sic) is read from the VAE; presumably its
    # spatial downsampling factor — confirm against AutoencoderKLWan.
    self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
135
+
136
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 512,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Tokenize and encode `prompt` with the T5 text encoder.

    Returns a list (one entry per prompt) of embedding tensors, each trimmed to
    the prompt's true (unpadded) token length.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    # Pad/truncate all prompts to `max_sequence_length` tokens.
    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_attention_mask = text_inputs.attention_mask
    # Re-tokenize without truncation only to detect and report dropped text.
    untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    # True token count per prompt (non-zero attention-mask entries).
    seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
    prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    # NOTE(review): `seq_lens` has `batch_size` entries while `prompt_embeds` now has
    # `batch_size * num_videos_per_prompt` rows, so this `zip` silently drops the
    # duplicated rows when `num_videos_per_prompt > 1` — confirm callers always pass 1.
    return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
179
+
180
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype

        Returns:
            Tuple of `(prompt_embeds, negative_prompt_embeds)`. When computed here they come
            from `_get_t5_prompt_embeds`, which returns a *list* of per-prompt tensors trimmed
            to their real token lengths (not a single padded tensor).
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            # Fall back to the pre-computed embeddings to infer the batch size.
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        # Negative embeddings are only needed (and only computed) for classifier-free guidance.
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            # Validate that the negative prompt matches the positive one in type and batch size.
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds
260
+
261
+ def prepare_latents(
262
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
263
+ ):
264
+ if isinstance(generator, list) and len(generator) != batch_size:
265
+ raise ValueError(
266
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
267
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
268
+ )
269
+
270
+ shape = (
271
+ batch_size,
272
+ num_channels_latents,
273
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
274
+ height // self.vae.spacial_compression_ratio,
275
+ width // self.vae.spacial_compression_ratio,
276
+ )
277
+
278
+ if latents is None:
279
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
280
+ else:
281
+ latents = latents.to(device)
282
+
283
+ # scale the initial noise by the standard deviation required by the scheduler
284
+ if hasattr(self.scheduler, "init_noise_sigma"):
285
+ latents = latents * self.scheduler.init_noise_sigma
286
+ return latents
287
+
288
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
289
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
290
+ frames = (frames / 2 + 0.5).clamp(0, 1)
291
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
292
+ frames = frames.cpu().float().numpy()
293
+ return frames
294
+
295
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
296
+ def prepare_extra_step_kwargs(self, generator, eta):
297
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
298
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
299
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
300
+ # and should be between [0, 1]
301
+
302
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
303
+ extra_step_kwargs = {}
304
+ if accepts_eta:
305
+ extra_step_kwargs["eta"] = eta
306
+
307
+ # check if the scheduler accepts generator
308
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
309
+ if accepts_generator:
310
+ extra_step_kwargs["generator"] = generator
311
+ return extra_step_kwargs
312
+
313
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
314
+ def check_inputs(
315
+ self,
316
+ prompt,
317
+ height,
318
+ width,
319
+ negative_prompt,
320
+ callback_on_step_end_tensor_inputs,
321
+ prompt_embeds=None,
322
+ negative_prompt_embeds=None,
323
+ ):
324
+ if height % 8 != 0 or width % 8 != 0:
325
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
326
+
327
+ if callback_on_step_end_tensor_inputs is not None and not all(
328
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
329
+ ):
330
+ raise ValueError(
331
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
332
+ )
333
+ if prompt is not None and prompt_embeds is not None:
334
+ raise ValueError(
335
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
336
+ " only forward one of the two."
337
+ )
338
+ elif prompt is None and prompt_embeds is None:
339
+ raise ValueError(
340
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
341
+ )
342
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
343
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
344
+
345
+ if prompt is not None and negative_prompt_embeds is not None:
346
+ raise ValueError(
347
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
348
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
349
+ )
350
+
351
+ if negative_prompt is not None and negative_prompt_embeds is not None:
352
+ raise ValueError(
353
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
354
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
355
+ )
356
+
357
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
358
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
359
+ raise ValueError(
360
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
361
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
362
+ f" {negative_prompt_embeds.shape}."
363
+ )
364
+
365
    @property
    def guidance_scale(self):
        """Classifier-free guidance weight recorded by the current `__call__`."""
        return self._guidance_scale
368
+
369
    @property
    def num_timesteps(self):
        """Number of denoising timesteps used by the current `__call__`."""
        return self._num_timesteps
372
+
373
    @property
    def attention_kwargs(self):
        """Extra attention kwargs recorded by the current `__call__` (may be None)."""
        return self._attention_kwargs
376
+
377
    @property
    def interrupt(self):
        """True when the caller has requested the denoising loop to stop early."""
        return self._interrupt
380
+
381
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        num_frames: int = 49,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: str = "numpy",
        return_dict: bool = False,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        comfyui_progressbar: bool = False,
    ) -> Union[WanPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt / negative_prompt: text conditioning (pass pre-computed
                `prompt_embeds` / `negative_prompt_embeds` instead to skip encoding).
            height, width, num_frames: output video geometry; height/width must be
                divisible by 8 (validated in `check_inputs`).
            num_inference_steps, timesteps: denoising schedule length or explicit steps.
            guidance_scale: classifier-free guidance weight; CFG is active when > 1.
            output_type: "numpy" (default), "latent", or any format accepted by
                `video_processor.postprocess_video`.
            comfyui_progressbar: additionally drive a ComfyUI progress bar.

        Returns:
            `WanPipelineOutput` with the generated frames in `.videos`.

        NOTE(review): `num_videos_per_prompt` is unconditionally forced to 1 below,
        so the parameter currently has no effect — confirm whether that is intended.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
        # Force a single video per prompt regardless of the argument value.
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        weight_dtype = self.text_encoder.dtype

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if do_classifier_free_guidance:
            # The embeds are lists of per-prompt tensors, so `+` concatenates
            # [uncond..., cond...]; this ordering must match the `chunk(2)` below.
            prompt_embeds = negative_prompt_embeds + prompt_embeds

        # 4. Prepare timesteps
        # Flow-matching schedulers take an extra `mu` shift parameter.
        if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
        else:
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)
        if comfyui_progressbar:
            # Imported lazily: comfy is only available when running inside ComfyUI.
            from comfy.utils import ProgressBar
            pbar = ProgressBar(num_inference_steps + 1)

        # 5. Prepare latents
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            weight_dtype,
            device,
            generator,
            latents,
        )
        if comfyui_progressbar:
            pbar.update(1)

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Token sequence length the transformer will see: latent frames times
        # patchified spatial positions.
        target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spacial_compression_ratio, height // self.vae.spacial_compression_ratio)
        seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
        # 7. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # Duplicate the latents for CFG so cond/uncond run in one forward pass.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                if hasattr(self.scheduler, "scale_model_input"):
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                with torch.cuda.amp.autocast(dtype=weight_dtype):
                    noise_pred = self.transformer(
                        x=latent_model_input,
                        context=prompt_embeds,
                        t=timestep,
                        seq_len=seq_len,
                    )

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # NOTE(review): relies on `locals()` containing each requested name;
                    # only names local to this function can be exposed to the callback.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if comfyui_progressbar:
                        pbar.update(1)

        # 8. Post-processing: decode unless the caller asked for raw latents.
        if output_type == "numpy":
            video = self.decode_latents(latents)
        elif not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        # NOTE(review): `return_dict=False` still returns a WanPipelineOutput (not a
        # tuple), and `torch.from_numpy` assumes the numpy branch above — confirm the
        # intended behavior for "latent"/processed output types.
        if not return_dict:
            video = torch.from_numpy(video)

        return WanPipelineOutput(videos=video)
videox_fun/reward/MPS/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ This folder is modified from the official [MPS](https://github.com/Kwai-Kolors/MPS/tree/main) repository.
videox_fun/reward/MPS/trainer/models/base_model.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+
4
+
5
@dataclass
class BaseModelConfig:
    """Empty marker base class for model configuration dataclasses."""
    pass
videox_fun/reward/MPS/trainer/models/clip_model.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from transformers import CLIPModel as HFCLIPModel
3
+ from transformers import AutoTokenizer
4
+
5
+ from torch import nn, einsum
6
+
7
+ # Modified: import
8
+ # from trainer.models.base_model import BaseModelConfig
9
+ from .base_model import BaseModelConfig
10
+
11
+ from transformers import CLIPConfig
12
+ from typing import Any, Optional, Tuple, Union
13
+ import torch
14
+
15
+ # Modified: import
16
+ # from trainer.models.cross_modeling import Cross_model
17
+ from .cross_modeling import Cross_model
18
+
19
+ import gc
20
+
21
class XCLIPModel(HFCLIPModel):
    """CLIP variant that exposes *per-token* projected features.

    Differs from the stock HF `CLIPModel`: `get_text_features` returns a tuple of
    (per-token features, EOS-pooled features) and `get_image_features` returns
    per-patch features, instead of only the pooled projections.
    """

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Return (projected last_hidden_state, projected pooled EOS embedding)."""

        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Unlike upstream HF (which projects only the pooled output), project every
        # token's hidden state so downstream cross-attention can use the full sequence.
        # pooled_output = text_outputs[1]
        # text_features = self.text_projection(pooled_output)
        last_hidden_state = text_outputs[0]
        text_features = self.text_projection(last_hidden_state)

        # Also keep the conventional pooled (EOS-token) projection.
        pooled_output = text_outputs[1]
        text_features_EOS = self.text_projection(pooled_output)


        # del last_hidden_state, text_outputs
        # gc.collect()

        return text_features, text_features_EOS

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Return per-patch projected image features (not the pooled projection)."""

        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project every patch token instead of only the pooled output.
        # pooled_output = vision_outputs[1] # pooled_output
        # image_features = self.visual_projection(pooled_output)
        last_hidden_state = vision_outputs[0]
        image_features = self.visual_projection(last_hidden_state)

        return image_features
93
+
94
+
95
+
96
@dataclass
class ClipModelConfig(BaseModelConfig):
    # Dotted path of the class to instantiate (presumably a Hydra-style
    # `_target_` — confirm against the trainer's config loader).
    _target_: str = "trainer.models.clip_model.CLIPModel"
    pretrained_model_name_or_path: str ="openai/clip-vit-base-patch32"
100
+
101
+
102
class CLIPModel(nn.Module):
    """MPS preference model: an `XCLIPModel` backbone plus a cross-attention head.

    `forward` fuses image patch tokens with condition-masked text tokens through
    `Cross_model` and returns the fused features used for preference scoring.
    """

    def __init__(self, config):
        super().__init__()
        # Modified: We convert the original ckpt (contains the entire model) to a `state_dict`.
        # self.model = XCLIPModel.from_pretrained(ckpt)
        self.model = XCLIPModel(config)
        self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)

    def get_text_features(self, *args, **kwargs):
        # Thin pass-through to the CLIP text tower.
        return self.model.get_text_features(*args, **kwargs)

    def get_image_features(self, *args, **kwargs):
        # Thin pass-through to the CLIP vision tower.
        return self.model.get_image_features(*args, **kwargs)

    def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
        """Return (text_EOS_features, fused_similarity_features).

        `text_inputs` / `condition_inputs` are token id tensors; `image_inputs`
        are pixel values (cast to half before the vision tower).
        """
        outputs = ()

        text_f, text_EOS = self.model.get_text_features(text_inputs) # B*77*1024
        outputs += text_EOS,

        image_f = self.model.get_image_features(image_inputs.half()) # 2B*257*1024
        # [B, 77, 1024]
        condition_f, _ = self.model.get_text_features(condition_inputs) # B*5*1024

        # Text tokens that are similar (above a small threshold, after max-pooling
        # over condition tokens and normalizing) stay visible; the rest get -inf
        # so the cross-attention softmax ignores them.
        sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
        sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
        sim_text_condition = sim_text_condition / sim_text_condition.max()
        mask = torch.where(sim_text_condition > 0.01, 0, float('-inf')) # B*1*77

        # Modified: Support both torch.float16 and torch.bfloat16
        # mask = mask.repeat(1,image_f.shape[1],1) # B*257*77
        model_dtype = next(self.cross_model.parameters()).dtype
        mask = mask.repeat(1,image_f.shape[1],1).to(model_dtype) # B*257*77
        # bc = int(image_f.shape[0]/2)

        # Modified: The original input consists of a (batch of) text and two (batches of) images,
        # primarily used to compute which (batch of) image is more consistent with the text.
        # The modified input consists of a (batch of) text and a (batch of) images.
        # sim0 = self.cross_model(image_f[:bc,:,:], text_f,mask.half())
        # sim1 = self.cross_model(image_f[bc:,:,:], text_f,mask.half())
        # outputs += sim0[:,0,:],
        # outputs += sim1[:,0,:],
        sim = self.cross_model(image_f, text_f,mask)
        # Keep only the first fused token as the similarity feature.
        outputs += sim[:,0,:],

        return outputs

    @property
    def logit_scale(self):
        """Learned CLIP logit temperature from the wrapped backbone."""
        return self.model.logit_scale

    def save(self, path):
        # Saves only the backbone; the cross-attention head is not persisted here.
        self.model.save_pretrained(path)
videox_fun/reward/MPS/trainer/models/cross_modeling.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import einsum, nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange, repeat
5
+
6
+ # helper functions
7
+
8
def exists(val):
    """Return True when *val* is an actual value, i.e. not None."""
    if val is None:
        return False
    return True
10
+
11
def default(val, d):
    """Return *val* unless it is None, in which case fall back to *d*."""
    return d if val is None else val
13
+
14
+ # normalization
15
+ # they use layernorm without bias, something that pytorch does not offer
16
+
17
+
18
class LayerNorm(nn.Module):
    """Layer normalization with a learnable scale and a frozen zero bias.

    PyTorch's LayerNorm has no bias-free variant, so the bias is registered as a
    buffer of zeros (kept in the state dict layout the checkpoints expect).
    """

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.register_buffer("bias", torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, self.weight, self.bias)
26
+
27
+ # residual
28
+
29
+
30
class Residual(nn.Module):
    """Wrap *fn* so the module computes ``fn(x, ...) + x`` (a skip connection)."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x
37
+
38
+
39
+ # rotary positional embedding
40
+ # https://arxiv.org/abs/2104.09864
41
+
42
+
43
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE, https://arxiv.org/abs/2104.09864)."""

    def __init__(self, dim):
        super().__init__()
        # One inverse frequency per channel pair, on a geometric 10000-base ladder.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        positions = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        # Outer product position x frequency gives the rotation angle table.
        freqs = torch.outer(positions, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)
53
+
54
+
55
def rotate_half(x):
    """Pairwise rotation helper for RoPE: split the last dim (must be even) into
    halves (x1, x2) and return their rotation (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
59
+
60
+
61
def apply_rotary_pos_emb(pos, t):
    """Apply the rotary phase table *pos* to tensor *t* (cos/sin rotation)."""
    rotated = rotate_half(t)
    return t * pos.cos() + rotated * pos.sin()
63
+
64
+
65
+ # classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
66
+ # https://arxiv.org/abs/2002.05202
67
+
68
+
69
class SwiGLU(nn.Module):
    """SwiGLU gated activation (https://arxiv.org/abs/2002.05202).

    Splits the last dimension into (value, gate) halves and returns
    ``silu(gate) * value``.
    """

    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * value
73
+
74
+
75
+ # parallel attention and feedforward with residual
76
+ # discovered by Wang et al + EleutherAI from GPT-J fame
77
+
78
class ParallelTransformerBlock(nn.Module):
    """Transformer block computing attention and feedforward in parallel from one
    fused projection (GPT-J / PaLM style), with rotary positions and multi-query
    attention (single shared key/value head)."""

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # One fused projection producing: multi-head queries, single-head key,
        # single-head value, and the doubled feedforward inner (for SwiGLU gating).
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # Cache for the rotary table; non-persistent so it never hits checkpoints.
        self.register_buffer("pos_emb", None, persistent=False)


    def get_rotary_embedding(self, n, device):
        # Reuse the cached table when it is long enough; otherwise recompute and cache.
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x, attn_mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm

        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner

        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and more efficient decoding obviously
        # https://arxiv.org/abs/1911.02150

        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings

        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # scale

        q = q * self.scale

        # similarity (k has no head dim: shared across query heads)

        sim = einsum("b h i d, b j d -> b h i j", q, k)


        # extra attention mask - for masking out attention from text CLS token to padding

        if exists(attn_mask):
            attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

        # attention (subtract row max for numerical stability before softmax)

        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values

        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads

        out = rearrange(out, "b h n d -> b n (h d)")
        # Parallel residual branches: attention output plus feedforward output.
        return self.attn_out(out) + self.ff_out(ff)
169
+
170
+ # cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
171
+
172
class CrossAttention(nn.Module):
    """Multi-query cross attention (queries from *x*, single shared key/value from
    *context*, as in PaLM), with an optional parallel SwiGLU feedforward branch."""

    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=12,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # Single key/value head shared by all query heads (multi-query attention).
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # whether to have parallel feedforward

        ff_inner_dim = ff_mult * dim

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        ) if parallel_ff else None

    def forward(self, x, context, mask):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension

        `mask` is an additive bias (0 = keep, -inf = drop) of shape
        (batch, query_len, context_len) — applied before the softmax.
        """

        # pre-layernorm, for queries and context

        x = self.norm(x)
        context = self.context_norm(context)

        # get queries

        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # scale

        q = q * self.scale

        # get key / values

        k, v = self.to_kv(context).chunk(2, dim=-1)

        # query / key similarity

        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # attention (broadcast the additive mask over all heads)
        mask = mask.unsqueeze(1).repeat(1,self.heads,1,1)
        sim = sim + mask # context mask
        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)

        # aggregate

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        # merge and combine heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        # add parallel feedforward (for multimodal layers)

        if exists(self.ff):
            out = out + self.ff(x)

        return out
259
+
260
+
261
class Cross_model(nn.Module):
    """Stack of residual (cross-attention -> parallel self-attention/FF) layers.

    Each layer first lets the query tokens attend over the context tokens (with an
    additive mask), then refines them with a parallel transformer block.
    """

    def __init__(
        self,
        dim=512,
        layer_num=4,
        dim_head=64,
        heads=8,
        ff_mult=4
    ):
        super().__init__()

        self.layers = nn.ModuleList(
            nn.ModuleList([
                Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
                Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
            ])
            for _ in range(layer_num)
        )

    def forward(
        self,
        query_tokens,
        context_tokens,
        mask
    ):
        for cross_attn, self_attn_ff in self.layers:
            query_tokens = cross_attn(query_tokens, context_tokens, mask)
            query_tokens = self_attn_ff(query_tokens)
        return query_tokens
videox_fun/reward/aesthetic_predictor_v2_5/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .siglip_v2_5 import (
2
+ AestheticPredictorV2_5Head,
3
+ AestheticPredictorV2_5Model,
4
+ AestheticPredictorV2_5Processor,
5
+ convert_v2_5_from_siglip,
6
+ )
7
+
8
+ __all__ = [
9
+ "AestheticPredictorV2_5Head",
10
+ "AestheticPredictorV2_5Model",
11
+ "AestheticPredictorV2_5Processor",
12
+ "convert_v2_5_from_siglip",
13
+ ]
videox_fun/reward/aesthetic_predictor_v2_5/siglip_v2_5.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Borrowed from https://github.com/discus0434/aesthetic-predictor-v2-5/blob/3125a9e/src/aesthetic_predictor_v2_5/siglip_v2_5.py
2
+ import os
3
+ from collections import OrderedDict
4
+ from os import PathLike
5
+ from typing import Final
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torchvision.transforms as transforms
10
+ from transformers import (
11
+ SiglipImageProcessor,
12
+ SiglipVisionConfig,
13
+ SiglipVisionModel,
14
+ logging,
15
+ )
16
+ from transformers.image_processing_utils import BatchFeature
17
+ from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
18
+
19
+ logging.set_verbosity_error()
20
+
21
+ URL: Final[str] = (
22
+ "https://github.com/discus0434/aesthetic-predictor-v2-5/raw/main/models/aesthetic_predictor_v2_5.pth"
23
+ )
24
+
25
+
26
class AestheticPredictorV2_5Head(nn.Module):
    """MLP head mapping a pooled SigLIP embedding to a single aesthetic score.

    The layer sequence (and therefore the state_dict key order) matches the
    released aesthetic-predictor-v2-5 checkpoint.
    """

    def __init__(self, config: SiglipVisionConfig) -> None:
        super().__init__()
        # (in_features, out_features, dropout_p) for each Linear+Dropout stage.
        stages = [
            (config.hidden_size, 1024, 0.5),
            (1024, 128, 0.5),
            (128, 64, 0.5),
            (64, 16, 0.2),
        ]
        modules = []
        for in_dim, out_dim, drop_p in stages:
            modules.append(nn.Linear(in_dim, out_dim))
            modules.append(nn.Dropout(drop_p))
        modules.append(nn.Linear(16, 1))
        self.scoring_head = nn.Sequential(*modules)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Return a [..., 1] score tensor for the given embeddings."""
        return self.scoring_head(image_embeds)
43
+
44
+
45
class AestheticPredictorV2_5Model(SiglipVisionModel):
    """SigLIP vision backbone with an aesthetic-scoring MLP head.

    The head scores the L2-normalized pooled embedding of the encoder.
    """

    PATCH_SIZE = 14

    def __init__(self, config: SiglipVisionConfig, *args, **kwargs) -> None:
        super().__init__(config, *args, **kwargs)
        self.layers = AestheticPredictorV2_5Head(config)
        self.post_init()
        # Default preprocessing matching the 384px SigLIP checkpoint
        # (not applied inside forward; provided for convenience).
        self.transforms = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
    ) -> tuple | ImageClassifierOutputWithNoAttention:
        """Score preprocessed images.

        Args:
            pixel_values: image batch for the SigLIP encoder.
            labels: optional regression targets; when given, an MSE loss
                against the predictions is returned.
            return_dict: whether to return a model-output dataclass
                (defaults to ``config.use_return_dict``).

        Returns:
            ``(loss, prediction, image_embeds)`` tuple, or an
            ``ImageClassifierOutputWithNoAttention`` when ``return_dict``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = super().forward(
            pixel_values=pixel_values,
            return_dict=return_dict,
        )
        image_embeds = outputs.pooler_output
        # Score the L2-normalized pooled embedding.
        image_embeds_norm = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        prediction = self.layers(image_embeds_norm)

        loss = None
        if labels is not None:
            loss_fct = nn.MSELoss()
            # BUGFIX: the original called `loss_fct()` with no arguments,
            # which raised a TypeError whenever labels were provided.
            loss = loss_fct(prediction.view(-1), labels.view(-1).to(prediction.dtype))

        if not return_dict:
            return (loss, prediction, image_embeds)

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=prediction,
            hidden_states=image_embeds,
        )
89
+
90
+
91
class AestheticPredictorV2_5Processor(SiglipImageProcessor):
    """Thin alias over ``SiglipImageProcessor`` with a SigLIP-384 default
    checkpoint for ``from_pretrained``."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs) -> BatchFeature:
        """Delegate preprocessing to the SigLIP image processor."""
        return super().__call__(*args, **kwargs)

    @classmethod
    def from_pretrained(
        cls,  # was mis-named `self`: a classmethod receives the class object
        pretrained_model_name_or_path: str
        | PathLike = "google/siglip-so400m-patch14-384",
        *args,
        **kwargs,
    ) -> "AestheticPredictorV2_5Processor":
        """Load the processor, defaulting to the SigLIP-so400m 384px config."""
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
107
+
108
+
109
def convert_v2_5_from_siglip(
    predictor_name_or_path: str | PathLike | None = None,
    encoder_model_name: str = "google/siglip-so400m-patch14-384",
    *args,
    **kwargs,
) -> tuple[AestheticPredictorV2_5Model, AestheticPredictorV2_5Processor]:
    """Build the v2.5 aesthetic predictor and its image processor.

    Loads the SigLIP encoder, then restores the scoring-head weights either
    from a local checkpoint or, when none is usable, from the released URL.
    Returns the model (in eval mode) and the matching processor.
    """
    model = AestheticPredictorV2_5Model.from_pretrained(
        encoder_model_name, *args, **kwargs
    )

    processor = AestheticPredictorV2_5Processor.from_pretrained(
        encoder_model_name, *args, **kwargs
    )

    # Prefer the local checkpoint when it exists; otherwise fetch the release.
    if predictor_name_or_path is not None and os.path.exists(predictor_name_or_path):
        state_dict = torch.load(predictor_name_or_path, map_location="cpu")
    else:
        state_dict = torch.hub.load_state_dict_from_url(URL, map_location="cpu")

    assert isinstance(state_dict, OrderedDict)

    model.layers.load_state_dict(state_dict)
    model.eval()

    return model, processor
videox_fun/reward/improved_aesthetic_predictor.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import CLIPModel
6
+ from torchvision.datasets.utils import download_url
7
+
8
+ URL = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/sac%2Blogos%2Bava1-l14-linearMSE.pth"
9
+ FILENAME = "sac+logos+ava1-l14-linearMSE.pth"
10
+ MD5 = "b1047fd767a00134b8fd6529bf19521a"
11
+
12
+
13
class MLP(nn.Module):
    """Aesthetic scoring head: 768-d CLIP embedding -> scalar score.

    The Sequential layout (and thus the state_dict keys ``layers.N.*``)
    matches the released sac+logos+ava1 linear-MSE checkpoint.
    """

    def __init__(self):
        super().__init__()
        stages = [
            nn.Linear(768, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, embed):
        """Return the raw aesthetic score for each embedding in the batch."""
        return self.layers(embed)
30
+
31
+
32
class ImprovedAestheticPredictor(nn.Module):
    """CLIP ViT-L/14 encoder plus the linear-MSE MLP head (aesthetic v2).

    Downloads the released head checkpoint into the torch hub cache when no
    usable local copy is supplied. Constructed in eval mode.
    """

    def __init__(self, encoder_path="openai/clip-vit-large-patch14", predictor_path=None):
        super().__init__()
        self.encoder = CLIPModel.from_pretrained(encoder_path)
        self.predictor = MLP()
        # Fall back to the released weights when the local path is missing.
        if predictor_path is None or not os.path.exists(predictor_path):
            download_url(URL, torch.hub.get_dir(), FILENAME, md5=MD5)
            predictor_path = os.path.join(torch.hub.get_dir(), FILENAME)
        self.predictor.load_state_dict(torch.load(predictor_path, map_location="cpu"))
        self.eval()

    def forward(self, pixel_values):
        """Score a batch of preprocessed images; returns a [B] tensor."""
        features = self.encoder.get_image_features(pixel_values=pixel_values)
        features = features / torch.linalg.vector_norm(features, dim=-1, keepdim=True)
        return self.predictor(features).squeeze(1)
videox_fun/reward/reward_fn.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+
4
+ import torch
5
+ import torchvision.transforms as transforms
6
+ from einops import rearrange
7
+ from torchvision.datasets.utils import download_url
8
+ from typing import Optional, Tuple
9
+
10
+
11
+ # All reward models.
12
+ __all__ = ["AestheticReward", "HPSReward", "PickScoreReward", "MPSReward"]
13
+
14
+
15
class BaseReward(ABC):
    """Abstract interface for reward models.

    Subclasses construct their model (and optional image transforms) in
    ``__init__`` and implement ``__call__``.
    """

    def __init__(self):
        # Subclasses set up their reward model / transforms here.
        pass

    @abstractmethod
    def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score frames of shape ``[B, C, T, H, W]`` (with optional per-video
        prompts) and return ``(loss, reward)``, each reduced by mean."""
        pass
29
+
30
class AestheticReward(BaseReward):
    """Aesthetic Predictor [V2](https://github.com/christophschuhmann/improved-aesthetic-predictor)
    and [V2.5](https://github.com/discus0434/aesthetic-predictor-v2-5) reward model.

    Scores each frame independently (no prompt needed) and converts the
    score into a loss via ``|reward - max_reward| * loss_scale`` (or
    ``-reward * loss_scale`` when ``max_reward`` is None).
    """
    def __init__(
        self,
        encoder_path="openai/clip-vit-large-patch14",
        predictor_path=None,
        version="v2",
        device="cpu",
        dtype=torch.float16,
        max_reward=10,
        loss_scale=0.1,
    ):
        from .improved_aesthetic_predictor import ImprovedAestheticPredictor
        # BUGFIX: convert_v2_5_from_siglip lives in the local
        # aesthetic_predictor_v2_5 package added alongside this module; the
        # previous `..video_caption.utils.siglip_v2_5` path does not exist here.
        from .aesthetic_predictor_v2_5 import convert_v2_5_from_siglip

        self.encoder_path = encoder_path
        self.predictor_path = predictor_path
        self.version = version
        self.device = device
        self.dtype = dtype
        self.max_reward = max_reward
        self.loss_scale = loss_scale

        if self.version not in ("v2", "v2.5"):
            raise ValueError("Only v2 and v2.5 are supported.")
        if self.version == "v2":
            assert "clip-vit-large-patch14" in encoder_path.lower()
            self.model = ImprovedAestheticPredictor(encoder_path=self.encoder_path, predictor_path=self.predictor_path)
            # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/preprocessor_config.json
            # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio.
            self.transform = transforms.Compose([
                transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
            ])
        elif self.version == "v2.5":
            assert "siglip-so400m-patch14-384" in encoder_path.lower()
            self.model, _ = convert_v2_5_from_siglip(encoder_model_name=self.encoder_path)
            # https://huggingface.co/google/siglip-so400m-patch14-384/blob/main/preprocessor_config.json
            self.transform = transforms.Compose([
                transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ])

        self.model.to(device=self.device, dtype=self.dtype)
        self.model.requires_grad_(False)

    def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score frames ``[B, C, T, H, W]``; ``batch_prompt`` is ignored.

        Returns (loss, reward), each averaged over frames and batch.
        """
        # [B, C, T, H, W] -> [T, B, C, H, W]: score one time step at a time.
        batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
        batch_loss, batch_reward = 0, 0
        for frames in batch_frames:
            pixel_values = torch.stack([self.transform(frame) for frame in frames])
            pixel_values = pixel_values.to(self.device, dtype=self.dtype)
            if self.version == "v2":
                reward = self.model(pixel_values)
            elif self.version == "v2.5":
                # BUGFIX: squeeze(-1) instead of squeeze() — a bare squeeze()
                # collapsed the reward to a 0-d scalar when batch size was 1.
                reward = self.model(pixel_values).logits.squeeze(-1)
            # Convert reward to loss in [0, 1].
            if self.max_reward is None:
                loss = (-1 * reward) * self.loss_scale
            else:
                loss = abs(reward - self.max_reward) * self.loss_scale
            batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()

        # After the rearrange, shape[0] is the frame count: average over time.
        return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
97
+
98
+
99
class HPSReward(BaseReward):
    """[HPS](https://github.com/tgxs002/HPSv2) v2 and v2.1 reward model.

    Scores each frame against its prompt with an HPS-finetuned ViT-H-14 CLIP
    and converts the diagonal image/text similarity into a loss.
    """
    def __init__(
        self,
        model_path=None,      # local checkpoint path; downloaded when missing
        version="v2.0",       # "v2.0" or "v2.1"
        device="cpu",
        dtype=torch.float16,
        max_reward=1,         # target reward; None switches to -reward loss
        loss_scale=1,
    ):
        # Local import: hpsv2 is an optional dependency, only needed here.
        from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer

        self.model_path = model_path
        self.version = version
        self.device = device
        self.dtype = dtype
        self.max_reward = max_reward
        self.loss_scale = loss_scale

        # Build the ViT-H-14 CLIP exactly as HPSv2's training setup expects;
        # the returned transforms are discarded (we use our own below).
        self.model, _, _ = create_model_and_transforms(
            "ViT-H-14",
            "laion2B-s32B-b79K",
            precision=self.dtype,
            device=self.device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False,
        )
        self.tokenizer = get_tokenizer("ViT-H-14")

        # https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json
        # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
        ])

        # Pick the release URL / md5 for the requested HPS version.
        if version == "v2.0":
            url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2_compressed.pt"
            filename = "HPS_v2_compressed.pt"
            md5 = "fd9180de357abf01fdb4eaad64631db4"
        elif version == "v2.1":
            url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2.1_compressed.pt"
            filename = "HPS_v2.1_compressed.pt"
            md5 = "4067542e34ba2553a738c5ac6c1d75c0"
        else:
            raise ValueError("Only v2.0 and v2.1 are supported.")
        # Download into the torch hub cache when no usable local path is given.
        if self.model_path is None or not os.path.exists(self.model_path):
            download_url(url, torch.hub.get_dir(), md5=md5)
            model_path = os.path.join(torch.hub.get_dir(), filename)

        state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
        self.model.load_state_dict(state_dict)
        self.model.to(device=self.device, dtype=self.dtype)
        self.model.requires_grad_(False)
        self.model.eval()

    def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score frames [B, C, T, H, W] against their per-video prompts.

        Returns (loss, reward), each averaged over frames and batch.
        """
        assert batch_frames.shape[0] == len(batch_prompt)
        # Compute batch reward and loss in frame-wise.
        # [B, C, T, H, W] -> [T, B, C, H, W]: iterate over time steps.
        batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
        batch_loss, batch_reward = 0, 0
        for frames in batch_frames:
            image_inputs = torch.stack([self.transform(frame) for frame in frames])
            image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
            text_inputs = self.tokenizer(batch_prompt).to(device=self.device)
            outputs = self.model(image_inputs, text_inputs)

            image_features, text_features = outputs["image_features"], outputs["text_features"]
            # Diagonal picks each image's similarity with its own prompt.
            logits = image_features @ text_features.T
            reward = torch.diagonal(logits)
            # Convert reward to loss in [0, 1].
            if self.max_reward is None:
                loss = (-1 * reward) * self.loss_scale
            else:
                loss = abs(reward - self.max_reward) * self.loss_scale

            batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()

        # After the rearrange, shape[0] is the frame count: average over time.
        return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
191
+
192
+
193
class PickScoreReward(BaseReward):
    """[PickScore](https://github.com/yuvalkirstain/PickScore) reward model.

    CLIP-H based preference model; the diagonal of the image/text similarity
    matrix pairs each frame with its own prompt.
    """
    def __init__(
        self,
        model_path="yuvalkirstain/PickScore_v1",  # HF repo or local path
        device="cpu",
        dtype=torch.float16,
        max_reward=1,         # target reward; None switches to -reward loss
        loss_scale=1,
    ):
        # Local import: transformers auto-classes only needed here.
        from transformers import AutoProcessor, AutoModel

        self.model_path = model_path
        self.device = device
        self.dtype = dtype
        self.max_reward = max_reward
        self.loss_scale = loss_scale

        # https://huggingface.co/yuvalkirstain/PickScore_v1/blob/main/preprocessor_config.json
        self.transform = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
        ])
        # Text processing comes from the base CLIP-H checkpoint; the scoring
        # weights come from the PickScore checkpoint.
        self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=self.dtype)
        self.model = AutoModel.from_pretrained(model_path, torch_dtype=self.dtype).eval().to(device)
        self.model.requires_grad_(False)
        self.model.eval()

    def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score frames [B, C, T, H, W] against their per-video prompts.

        Returns (loss, reward), each averaged over frames and batch.
        """
        assert batch_frames.shape[0] == len(batch_prompt)
        # Compute batch reward and loss in frame-wise.
        # [B, C, T, H, W] -> [T, B, C, H, W]: iterate over time steps.
        batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
        batch_loss, batch_reward = 0, 0
        for frames in batch_frames:
            image_inputs = torch.stack([self.transform(frame) for frame in frames])
            image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
            text_inputs = self.processor(
                text=batch_prompt,
                padding=True,
                truncation=True,
                max_length=77,
                return_tensors="pt",
            ).to(self.device)
            image_features = self.model.get_image_features(pixel_values=image_inputs)
            text_features = self.model.get_text_features(**text_inputs)
            # L2-normalize so the dot product is a cosine similarity.
            image_features = image_features / torch.norm(image_features, dim=-1, keepdim=True)
            text_features = text_features / torch.norm(text_features, dim=-1, keepdim=True)

            # Diagonal picks each image's similarity with its own prompt.
            logits = image_features @ text_features.T
            reward = torch.diagonal(logits)
            # Convert reward to loss in [0, 1].
            if self.max_reward is None:
                loss = (-1 * reward) * self.loss_scale
            else:
                loss = abs(reward - self.max_reward) * self.loss_scale

            batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()

        # After the rearrange, shape[0] is the frame count: average over time.
        return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
254
+
255
+
256
class MPSReward(BaseReward):
    """[MPS](https://github.com/Kwai-Kolors/MPS) reward model.

    Multi-dimensional preference score: a CLIP-H variant conditioned on a
    text description of the aspects to judge (``self.condition``).
    """
    def __init__(
        self,
        model_path=None,      # local checkpoint path; downloaded when missing
        device="cpu",
        dtype=torch.float16,
        max_reward=1,         # target reward; None switches to -reward loss
        loss_scale=1,
    ):
        from transformers import AutoTokenizer, AutoConfig
        from .MPS.trainer.models.clip_model import CLIPModel

        self.model_path = model_path
        self.device = device
        self.dtype = dtype
        # Default "overall" condition prompt listing every judged aspect.
        self.condition = "light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things."
        self.max_reward = max_reward
        self.loss_scale = loss_scale

        processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
        # https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json
        # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
        ])

        # We convert the original [ckpt](http://drive.google.com/file/d/17qrK_aJkVNM75ZEvMEePpLj6L867MLkN/view?usp=sharing)
        # (contains the entire model) to a `state_dict`.
        url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/MPS_overall.pth"
        filename = "MPS_overall.pth"
        md5 = "1491cbbbd20565747fe07e7572e2ac56"
        # Download into the torch hub cache when no usable local path is given.
        if self.model_path is None or not os.path.exists(self.model_path):
            download_url(url, torch.hub.get_dir(), md5=md5)
            model_path = os.path.join(torch.hub.get_dir(), filename)

        self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
        config = AutoConfig.from_pretrained(processor_name_or_path)
        self.model = CLIPModel(config)
        state_dict = torch.load(model_path, map_location="cpu")
        # strict=False: the converted state_dict may omit buffers/extras.
        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(device=self.device, dtype=self.dtype)
        self.model.requires_grad_(False)
        self.model.eval()

    def _tokenize(self, caption):
        # Fixed-length padding so prompt and condition batches align.
        input_ids = self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids

        return input_ids

    def __call__(
        self,
        batch_frames: torch.Tensor,
        batch_prompt: list[str],
        batch_condition: Optional[list[str]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score frames [B, C, T, H, W] against prompts under a condition.

        ``batch_condition`` defaults to the overall-quality condition prompt.
        Returns (loss, reward), each averaged over frames and batch.
        """
        if batch_condition is None:
            batch_condition = [self.condition] * len(batch_prompt)
        # [B, C, T, H, W] -> [T, B, C, H, W]: iterate over time steps.
        batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
        batch_loss, batch_reward = 0, 0
        for frames in batch_frames:
            image_inputs = torch.stack([self.transform(frame) for frame in frames])
            image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
            text_inputs = self._tokenize(batch_prompt).to(self.device)
            condition_inputs = self._tokenize(batch_condition).to(device=self.device)
            text_features, image_features = self.model(text_inputs, image_inputs, condition_inputs)

            # L2-normalize so the dot product is a cosine similarity.
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # reward = self.model.logit_scale.exp() * torch.diag(torch.einsum('bd,cd->bc', text_features, image_features))
            logits = image_features @ text_features.T
            reward = torch.diagonal(logits)
            # Convert reward to loss in [0, 1].
            if self.max_reward is None:
                loss = (-1 * reward) * self.loss_scale
            else:
                loss = abs(reward - self.max_reward) * self.loss_scale

            batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()

        # After the rearrange, shape[0] is the frame count: average over time.
        return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
345
+
346
+
347
if __name__ == "__main__":
    # Smoke-test driver: sample frames from two local videos and print every
    # reward model's (loss, reward). Replace the placeholder paths/prompts.
    import numpy as np
    from decord import VideoReader

    video_path_list = ["your_video_path_1.mp4", "your_video_path_2.mp4"]
    prompt_list = ["your_prompt_1", "your_prompt_2"]
    num_sampled_frames = 8

    to_tensor = transforms.ToTensor()

    sampled_frames_list = []
    for video_path in video_path_list:
        vr = VideoReader(video_path)
        # Evenly spaced frame indices over the whole clip (endpoint excluded).
        sampled_frame_indices = np.linspace(0, len(vr), num_sampled_frames, endpoint=False, dtype=int)
        sampled_frames = vr.get_batch(sampled_frame_indices).asnumpy()
        sampled_frames = torch.stack([to_tensor(frame) for frame in sampled_frames])
        sampled_frames_list.append(sampled_frames)
    sampled_frames = torch.stack(sampled_frames_list)
    # Reward models expect [B, C, T, H, W].
    sampled_frames = rearrange(sampled_frames, "b t c h w -> b c t h w")

    aesthetic_reward_v2 = AestheticReward(device="cuda", dtype=torch.bfloat16)
    print(f"aesthetic_reward_v2: {aesthetic_reward_v2(sampled_frames)}")

    aesthetic_reward_v2_5 = AestheticReward(
        encoder_path="google/siglip-so400m-patch14-384", version="v2.5", device="cuda", dtype=torch.bfloat16
    )
    print(f"aesthetic_reward_v2_5: {aesthetic_reward_v2_5(sampled_frames)}")

    hps_reward_v2 = HPSReward(device="cuda", dtype=torch.bfloat16)
    print(f"hps_reward_v2: {hps_reward_v2(sampled_frames, prompt_list)}")

    hps_reward_v2_1 = HPSReward(version="v2.1", device="cuda", dtype=torch.bfloat16)
    print(f"hps_reward_v2_1: {hps_reward_v2_1(sampled_frames, prompt_list)}")

    pick_score = PickScoreReward(device="cuda", dtype=torch.bfloat16)
    print(f"pick_score_reward: {pick_score(sampled_frames, prompt_list)}")

    mps_score = MPSReward(device="cuda", dtype=torch.bfloat16)
    print(f"mps_reward: {mps_score(sampled_frames, prompt_list)}")
videox_fun/ui/cogvideox_fun_ui.py ADDED
@@ -0,0 +1,667 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
+ """
3
+ import os
4
+ import random
5
+
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+ from safetensors import safe_open
12
+
13
+ from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
14
+ from ..models import (AutoencoderKLCogVideoX, CogVideoXTransformer3DModel,
15
+ T5EncoderModel, T5Tokenizer)
16
+ from ..pipeline import (CogVideoXFunControlPipeline,
17
+ CogVideoXFunInpaintPipeline, CogVideoXFunPipeline)
18
+ from ..utils.fp8_optimization import convert_weight_dtype_wrapper
19
+ from ..utils.lora_utils import merge_lora, unmerge_lora
20
+ from ..utils.utils import (get_image_to_video_latent,
21
+ get_video_to_video_latent, save_videos_grid)
22
+ from .controller import (Fun_Controller, Fun_Controller_Client,
23
+ all_cheduler_dict, css, ddpm_scheduler_dict,
24
+ flow_scheduler_dict, gradio_version,
25
+ gradio_version_is_above_4)
26
+ from .ui import (create_cfg_and_seedbox,
27
+ create_fake_finetune_models_checkpoints,
28
+ create_fake_height_width, create_fake_model_checkpoints,
29
+ create_fake_model_type, create_finetune_models_checkpoints,
30
+ create_generation_method,
31
+ create_generation_methods_and_video_length,
32
+ create_height_width, create_model_checkpoints,
33
+ create_model_type, create_prompts, create_samplers,
34
+ create_ui_outputs)
35
+
36
+
37
+ class CogVideoXFunController(Fun_Controller):
38
    def update_diffusion_transformer(self, diffusion_transformer_dropdown):
        """(Re)load VAE, transformer, tokenizer and text encoder from the
        selected checkpoint directory, assemble the matching CogVideoX-Fun
        pipeline, and apply the configured memory/offload strategy.

        Returns a no-op ``gr.update()`` for the Gradio callback chain.
        """
        print("Update diffusion transformer")
        self.diffusion_transformer_dropdown = diffusion_transformer_dropdown
        if diffusion_transformer_dropdown == "none":
            return gr.update()
        self.vae = AutoencoderKLCogVideoX.from_pretrained(
            diffusion_transformer_dropdown,
            subfolder="vae",
        ).to(self.weight_dtype)

        # Get Transformer
        self.transformer = CogVideoXTransformer3DModel.from_pretrained(
            diffusion_transformer_dropdown,
            subfolder="transformer",
            low_cpu_mem_usage=True,
        ).to(self.weight_dtype)

        # Get tokenizer and text_encoder
        tokenizer = T5Tokenizer.from_pretrained(
            diffusion_transformer_dropdown, subfolder="tokenizer"
        )
        text_encoder = T5EncoderModel.from_pretrained(
            diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
        )

        # Get pipeline
        if self.model_type == "Inpaint":
            # Extra input channels beyond the VAE latents indicate an
            # inpaint-capable transformer; otherwise use the plain pipeline.
            if self.transformer.config.in_channels != self.vae.config.latent_channels:
                self.pipeline = CogVideoXFunInpaintPipeline(
                    tokenizer=tokenizer,
                    text_encoder=text_encoder,
                    vae=self.vae,
                    transformer=self.transformer,
                    scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
                )
            else:
                self.pipeline = CogVideoXFunPipeline(
                    tokenizer=tokenizer,
                    text_encoder=text_encoder,
                    vae=self.vae,
                    transformer=self.transformer,
                    scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
                )
        else:
            self.pipeline = CogVideoXFunControlPipeline(
                diffusion_transformer_dropdown,
                vae=self.vae,
                transformer=self.transformer,
                scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
                torch_dtype=self.weight_dtype
            )

        # Multi-GPU sequence-parallel inference (ulysses / ring attention).
        if self.ulysses_degree > 1 or self.ring_degree > 1:
            self.transformer.enable_multi_gpus_inference()

        # Memory strategy: sequential offload (lowest VRAM), model offload
        # with optional fp8 weight wrapping, or keep everything on device.
        if self.GPU_memory_mode == "sequential_cpu_offload":
            self.pipeline.enable_sequential_cpu_offload(device=self.device)
        elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
            convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype)
            self.pipeline.enable_model_cpu_offload(device=self.device)
        elif self.GPU_memory_mode == "model_cpu_offload":
            self.pipeline.enable_model_cpu_offload(device=self.device)
        else:
            self.pipeline.to(self.device)
        print("Update diffusion transformer done")
        return gr.update()
104
+
105
+ def generate(
106
+ self,
107
+ diffusion_transformer_dropdown,
108
+ base_model_dropdown,
109
+ lora_model_dropdown,
110
+ lora_alpha_slider,
111
+ prompt_textbox,
112
+ negative_prompt_textbox,
113
+ sampler_dropdown,
114
+ sample_step_slider,
115
+ resize_method,
116
+ width_slider,
117
+ height_slider,
118
+ base_resolution,
119
+ generation_method,
120
+ length_slider,
121
+ overlap_video_length,
122
+ partial_video_length,
123
+ cfg_scale_slider,
124
+ start_image,
125
+ end_image,
126
+ validation_video,
127
+ validation_video_mask,
128
+ control_video,
129
+ denoise_strength,
130
+ seed_textbox,
131
+ is_api = False,
132
+ ):
133
+ self.clear_cache()
134
+
135
+ self.input_check(
136
+ resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api
137
+ )
138
+ is_image = True if generation_method == "Image Generation" else False
139
+
140
+ if self.base_model_path != base_model_dropdown:
141
+ self.update_base_model(base_model_dropdown)
142
+
143
+ if self.lora_model_path != lora_model_dropdown:
144
+ self.update_lora_model(lora_model_dropdown)
145
+
146
+ self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
147
+
148
+ if resize_method == "Resize according to Reference":
149
+ height_slider, width_slider = self.get_height_width_from_reference(
150
+ base_resolution, start_image, validation_video, control_video,
151
+ )
152
+ if self.lora_model_path != "none":
153
+ # lora part
154
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
155
+
156
+ if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
157
+ else: seed_textbox = np.random.randint(0, 1e10)
158
+ generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))
159
+
160
+ try:
161
+ if self.model_type == "Inpaint":
162
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
163
+ if generation_method == "Long Video Generation":
164
+ if validation_video is not None:
165
+ raise gr.Error(f"Video to Video is not Support Long Video Generation now.")
166
+ init_frames = 0
167
+ last_frames = init_frames + partial_video_length
168
+ while init_frames < length_slider:
169
+ if last_frames >= length_slider:
170
+ _partial_video_length = length_slider - init_frames
171
+ _partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1
172
+
173
+ if _partial_video_length <= 0:
174
+ break
175
+ else:
176
+ _partial_video_length = partial_video_length
177
+
178
+ if last_frames >= length_slider:
179
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
180
+ else:
181
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
182
+
183
+ with torch.no_grad():
184
+ sample = self.pipeline(
185
+ prompt_textbox,
186
+ negative_prompt = negative_prompt_textbox,
187
+ num_inference_steps = sample_step_slider,
188
+ guidance_scale = cfg_scale_slider,
189
+ width = width_slider,
190
+ height = height_slider,
191
+ num_frames = _partial_video_length,
192
+ generator = generator,
193
+
194
+ video = input_video,
195
+ mask_video = input_video_mask,
196
+ strength = 1,
197
+ ).videos
198
+
199
+ if init_frames != 0:
200
+ mix_ratio = torch.from_numpy(
201
+ np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
202
+ ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
203
+
204
+ new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
205
+ sample[:, :, :overlap_video_length] * mix_ratio
206
+ new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
207
+
208
+ sample = new_sample
209
+ else:
210
+ new_sample = sample
211
+
212
+ if last_frames >= length_slider:
213
+ break
214
+
215
+ start_image = [
216
+ Image.fromarray(
217
+ (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
218
+ ) for _index in range(-overlap_video_length, 0)
219
+ ]
220
+
221
+ init_frames = init_frames + _partial_video_length - overlap_video_length
222
+ last_frames = init_frames + _partial_video_length
223
+ else:
224
+ if validation_video is not None:
225
+ input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=8)
226
+ strength = denoise_strength
227
+ else:
228
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
229
+ strength = 1
230
+
231
+ sample = self.pipeline(
232
+ prompt_textbox,
233
+ negative_prompt = negative_prompt_textbox,
234
+ num_inference_steps = sample_step_slider,
235
+ guidance_scale = cfg_scale_slider,
236
+ width = width_slider,
237
+ height = height_slider,
238
+ num_frames = length_slider if not is_image else 1,
239
+ generator = generator,
240
+
241
+ video = input_video,
242
+ mask_video = input_video_mask,
243
+ strength = strength,
244
+ ).videos
245
+ else:
246
+ sample = self.pipeline(
247
+ prompt_textbox,
248
+ negative_prompt = negative_prompt_textbox,
249
+ num_inference_steps = sample_step_slider,
250
+ guidance_scale = cfg_scale_slider,
251
+ width = width_slider,
252
+ height = height_slider,
253
+ num_frames = length_slider if not is_image else 1,
254
+ generator = generator
255
+ ).videos
256
+ else:
257
+ input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=8)
258
+
259
+ sample = self.pipeline(
260
+ prompt_textbox,
261
+ negative_prompt = negative_prompt_textbox,
262
+ num_inference_steps = sample_step_slider,
263
+ guidance_scale = cfg_scale_slider,
264
+ width = width_slider,
265
+ height = height_slider,
266
+ num_frames = length_slider if not is_image else 1,
267
+ generator = generator,
268
+
269
+ control_video = input_video,
270
+ ).videos
271
+ except Exception as e:
272
+ self.clear_cache()
273
+ if self.lora_model_path != "none":
274
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
275
+ if is_api:
276
+ return "", f"Error. error information is {str(e)}"
277
+ else:
278
+ return gr.update(), gr.update(), f"Error. error information is {str(e)}"
279
+
280
+ self.clear_cache()
281
+ # lora part
282
+ if self.lora_model_path != "none":
283
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
284
+
285
+ save_sample_path = self.save_outputs(
286
+ is_image, length_slider, sample, fps=8
287
+ )
288
+
289
+ if is_image or length_slider == 1:
290
+ if is_api:
291
+ return save_sample_path, "Success"
292
+ else:
293
+ if gradio_version_is_above_4:
294
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
295
+ else:
296
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
297
+ else:
298
+ if is_api:
299
+ return save_sample_path, "Success"
300
+ else:
301
+ if gradio_version_is_above_4:
302
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
303
+ else:
304
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
305
+
306
# Module-level aliases: the full controller doubles as the host-side
# implementation (used by ui_host), while client mode reuses the lightweight
# request-forwarding controller (used by ui_client).
CogVideoXFunController_Host = CogVideoXFunController
CogVideoXFunController_Client = Fun_Controller_Client
308
+
309
def ui(GPU_memory_mode, scheduler_dict, ulysses_degree, ring_degree, weight_dtype, savedir_sample=None):
    """Build the fully interactive CogVideoX-Fun Gradio demo.

    The controller starts with no model loaded (model_name=None, type
    "Inpaint"); the user picks checkpoints and LoRAs from the UI.

    Returns:
        (demo, controller): the gr.Blocks app and its backing controller.

    NOTE(review): ``css``, ``gradio_version_is_above_4``,
    ``create_cfg_and_seedbox`` and ``create_ui_outputs`` are defined elsewhere
    in this file/package — not visible in this chunk.
    """
    controller = CogVideoXFunController(
        GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
        ulysses_degree=ulysses_degree, ring_degree=ring_degree,
        config_path=None, enable_teacache=None, teacache_threshold=None, weight_dtype=weight_dtype,
        savedir_sample=savedir_sample,
    )

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # CogVideoX-Fun:

            A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.

            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
            """
        )
        # Model selection panel: type, transformer checkpoint, optional
        # Dreambooth base model and LoRA.
        with gr.Column(variant="panel"):
            model_type = create_model_type(visible=True)
            diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \
                create_model_checkpoints(controller, visible=True)
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \
                create_finetune_models_checkpoints(controller, visible=True)

        # Generation panel: prompts, sampler, resolution, length, source
        # inputs, CFG/seed, and the output widgets.
        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts()

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)

                    resize_method, width_slider, height_slider, base_resolution = create_height_width(
                        default_height = 384, default_width = 672, maximum_height = 1344,
                        maximum_width = 1344,
                    )
                    gr.Markdown(
                        """
                        V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
                        (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
                        """
                    )
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation", "Image Generation", "Long Video Generation"],
                            default_video_length=49,
                            maximum_video_length=85,
                        )
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox
                    )
                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

            # Switching Inpaint/Control hands the new type to the controller.
            model_type.change(
                fn=controller.update_model_type,
                inputs=[model_type],
                outputs=[]
            )

            def upload_generation_method(generation_method):
                # Updates (length_slider, overlap_video_length,
                # partial_video_length); the two extra sliders are only shown
                # in the "Long Video Generation" branch.
                if generation_method == "Video Generation":
                    return [gr.update(visible=True, maximum=85, value=49, interactive=True), gr.update(visible=False), gr.update(visible=False)]
                elif generation_method == "Image Generation":
                    return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)]
                else:
                    # Long Video Generation.
                    # NOTE(review): maximum=1344 equals the resolution cap used
                    # above — confirm it is the intended frame-count ceiling.
                    return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
            generation_method.change(
                upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
            )

            def upload_source_method(source_method):
                # Updates (image_to_video_col, video_to_video_col,
                # control_video_col, start_image, end_image, validation_video,
                # validation_video_mask, control_video): reveal only the chosen
                # source column and clear inputs belonging to the other modes.
                if source_method == "Text to Video (文本到视频)":
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Image to Video (图片到视频)":
                    return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Video to Video (视频到视频)":
                    return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
            source_method.change(
                upload_source_method, source_method, [
                    image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
                    validation_video, validation_video_mask, control_video
                ]
            )

            def upload_resize_method(resize_method):
                # "Generate by" exposes explicit W/H sliders; otherwise the
                # base-resolution radio drives sizing from the reference media.
                if resize_method == "Generate by":
                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
            resize_method.change(
                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
            )

            # The input order here must match the controller.generate signature.
            generate_button.click(
                fn=controller.generate,
                inputs=[
                    diffusion_transformer_dropdown,
                    base_model_dropdown,
                    lora_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    resize_method,
                    width_slider,
                    height_slider,
                    base_resolution,
                    generation_method,
                    length_slider,
                    overlap_video_length,
                    partial_video_length,
                    cfg_scale_slider,
                    start_image,
                    end_image,
                    validation_video,
                    validation_video_mask,
                    control_video,
                    denoise_strength,
                    seed_textbox,
                ],
                outputs=[result_image, result_video, infer_progress]
            )
    return demo, controller
439
+
440
def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, ulysses_degree, ring_degree, weight_dtype, savedir_sample=None):
    """Build the host-mode demo: the model is fixed at startup.

    Unlike ui(), the checkpoint/type pickers are read-only ("fake") widgets and
    there is no Long Video Generation mode.

    Returns:
        (demo, controller): the gr.Blocks app and its backing host controller.
    """
    controller = CogVideoXFunController_Host(
        GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type,
        ulysses_degree=ulysses_degree, ring_degree=ring_degree,
        config_path=None, enable_teacache=None, teacache_threshold=None, weight_dtype=weight_dtype,
        savedir_sample=savedir_sample,
    )

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # CogVideoX-Fun

            A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.

            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
            """
        )
        # Read-only model panel: everything is fixed by the host process.
        with gr.Column(variant="panel"):
            model_type = create_fake_model_type(visible=True)
            diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)

        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts()

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)

                    resize_method, width_slider, height_slider, base_resolution = create_height_width(
                        default_height = 384, default_width = 672, maximum_height = 1344,
                        maximum_width = 1344,
                    )
                    gr.Markdown(
                        """
                        V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
                        (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
                        """
                    )
                    # Host mode offers only plain video/image generation.
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation", "Image Generation"],
                            default_video_length=49,
                            maximum_video_length=85,
                        )
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox
                    )
                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

            def upload_generation_method(generation_method):
                # Only the two options below are offered in host mode, so the
                # missing else-branch is unreachable.
                if generation_method == "Video Generation":
                    return gr.update(visible=True, minimum=8, maximum=85, value=49, interactive=True)
                elif generation_method == "Image Generation":
                    return gr.update(minimum=1, maximum=1, value=1, interactive=False)
            generation_method.change(
                upload_generation_method, generation_method, [length_slider]
            )

            def upload_source_method(source_method):
                # Same column/input toggling as in ui(): reveal one source
                # column and clear inputs that belong to the other modes.
                if source_method == "Text to Video (文本到视频)":
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Image to Video (图片到视频)":
                    return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Video to Video (视频到视频)":
                    return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
            source_method.change(
                upload_source_method, source_method, [
                    image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
                    validation_video, validation_video_mask, control_video
                ]
            )

            def upload_resize_method(resize_method):
                # Explicit W/H sliders vs. reference-driven base resolution.
                if resize_method == "Generate by":
                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
            resize_method.change(
                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
            )

            # The input order here must match the controller.generate signature.
            generate_button.click(
                fn=controller.generate,
                inputs=[
                    diffusion_transformer_dropdown,
                    base_model_dropdown,
                    lora_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    resize_method,
                    width_slider,
                    height_slider,
                    base_resolution,
                    generation_method,
                    length_slider,
                    overlap_video_length,
                    partial_video_length,
                    cfg_scale_slider,
                    start_image,
                    end_image,
                    validation_video,
                    validation_video_mask,
                    control_video,
                    denoise_strength,
                    seed_textbox,
                ],
                outputs=[result_image, result_video, infer_progress]
            )
    return demo, controller
560
+
561
def ui_client(scheduler_dict, model_name, savedir_sample=None):
    """Build the client-mode demo that forwards requests to a remote host.

    Uses the lightweight CogVideoXFunController_Client, read-only model
    widgets, locked width/height controls, and no Video Control mode.

    Returns:
        (demo, controller): the gr.Blocks app and its client controller.
    """
    controller = CogVideoXFunController_Client(scheduler_dict, savedir_sample)

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # CogVideoX-Fun

            A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.

            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
            """
        )
        # Read-only model panel — the remote host owns the actual weights.
        with gr.Column(variant="panel"):
            diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)

        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts()

            with gr.Row():
                with gr.Column():
                    # Clients are capped at 50 sampling steps.
                    sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50)

                    resize_method, width_slider, height_slider, base_resolution = create_fake_height_width(
                        default_height = 384, default_width = 672, maximum_height = 1344,
                        maximum_width = 1344,
                    )
                    gr.Markdown(
                        """
                        V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
                        (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
                        """
                    )
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation", "Image Generation"],
                            default_video_length=49,
                            maximum_video_length=85,
                        )
                    # No "Video Control" source in client mode.
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"], prompt_textbox
                    )

                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

            def upload_generation_method(generation_method):
                # Only these two options exist in client mode, so the missing
                # else-branch is unreachable.
                if generation_method == "Video Generation":
                    return gr.update(visible=True, minimum=5, maximum=85, value=49, interactive=True)
                elif generation_method == "Image Generation":
                    return gr.update(minimum=1, maximum=1, value=1, interactive=False)
            generation_method.change(
                upload_generation_method, generation_method, [length_slider]
            )

            def upload_source_method(source_method):
                # Six updates (no control-video widgets in client mode):
                # (image_to_video_col, video_to_video_col, start_image,
                # end_image, validation_video, validation_video_mask).
                if source_method == "Text to Video (文本到视频)":
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Image to Video (图片到视频)":
                    return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
                else:
                    return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
            source_method.change(
                upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
            )

            def upload_resize_method(resize_method):
                # Explicit W/H sliders vs. reference-driven base resolution.
                if resize_method == "Generate by":
                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
            resize_method.change(
                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
            )

            # Shorter input list than ui()/ui_host(): no overlap/partial length
            # and no control video — must match the client generate signature.
            generate_button.click(
                fn=controller.generate,
                inputs=[
                    diffusion_transformer_dropdown,
                    base_model_dropdown,
                    lora_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    resize_method,
                    width_slider,
                    height_slider,
                    base_resolution,
                    generation_method,
                    length_slider,
                    cfg_scale_slider,
                    start_image,
                    end_image,
                    validation_video,
                    validation_video_mask,
                    denoise_strength,
                    seed_textbox,
                ],
                outputs=[result_image, result_video, infer_progress]
            )
    return demo, controller
videox_fun/ui/ui.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import gradio as gr
4
+
5
+
6
def create_model_type(visible):
    """Render the model-type selector (Inpaint vs. Control) and return its dropdown."""
    gr.Markdown(
        """
        ### Model Type (模型的种类,正常模型还是控制模型).
        """,
        visible=visible,
    )
    with gr.Row():
        dropdown_config = {
            "label": "The model type of the model (模型的种类,正常模型还是控制模型)",
            "choices": ["Inpaint", "Control"],
            "value": "Inpaint",
            "visible": visible,
            "interactive": True,
        }
        model_type = gr.Dropdown(**dropdown_config)
    return model_type
22
+
23
def create_fake_model_type(visible):
    """Read-only model-type display for deployments where the type is fixed."""
    gr.Markdown(
        """
        ### Model Type (模型的种类,正常模型还是控制模型).
        """,
        visible=visible,
    )
    with gr.Row():
        locked_config = {
            "label": "The model type of the model (模型的种类,正常模型还是控制模型)",
            "choices": ["Inpaint", "Control"],
            "value": "Inpaint",
            "interactive": False,
            "visible": visible,
        }
        model_type = gr.Dropdown(**locked_config)
    return model_type
39
+
40
def create_model_checkpoints(controller, visible):
    """Checkpoint picker: a dropdown of transformer paths plus a refresh button."""
    gr.Markdown(
        """
        ### Model checkpoints (模型路径).
        """
    )
    with gr.Row(visible=visible):
        diffusion_transformer_dropdown = gr.Dropdown(
            label="Pretrained Model Path (预训练模型路径)",
            choices=controller.diffusion_transformer_list,
            value="none",
            interactive=True,
        )
        # Selecting a checkpoint immediately loads it through the controller.
        diffusion_transformer_dropdown.change(
            fn=controller.update_diffusion_transformer,
            inputs=[diffusion_transformer_dropdown],
            outputs=[diffusion_transformer_dropdown],
        )

        diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

        def _rescan_checkpoints():
            # Re-scan disk, then push the fresh choice list into the dropdown.
            controller.refresh_diffusion_transformer()
            return gr.update(choices=controller.diffusion_transformer_list)

        diffusion_transformer_refresh_button.click(
            fn=_rescan_checkpoints, inputs=[], outputs=[diffusion_transformer_dropdown]
        )

    return diffusion_transformer_dropdown, diffusion_transformer_refresh_button
66
+
67
def create_fake_model_checkpoints(model_name, visible):
    """Read-only checkpoint display used when the model is fixed by the host."""
    gr.Markdown(
        """
        ### Model checkpoints (模型路径).
        """
    )
    with gr.Row(visible=visible):
        # Single fixed choice: the preloaded model path.
        diffusion_transformer_dropdown = gr.Dropdown(
            choices=[model_name],
            value=model_name,
            label="Pretrained Model Path (预训练模型路径)",
            interactive=False,
        )
    return diffusion_transformer_dropdown
81
+
82
def create_finetune_models_checkpoints(controller, visible):
    """Pickers for optional Dreambooth base and LoRA weights, plus a refresh button."""
    with gr.Row(visible=visible):
        base_model_dropdown = gr.Dropdown(
            label="Select base Dreambooth model (选择基模型[非必需])",
            choices=controller.personalized_model_list,
            value="none",
            interactive=True,
        )

        lora_model_dropdown = gr.Dropdown(
            label="Select LoRA model (选择LoRA模型[非必需])",
            choices=["none"] + controller.personalized_model_list,
            value="none",
            interactive=True,
        )

        lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)

        personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

        def _rescan_personalized_models():
            # Re-scan the personalized-model folder, then refresh both dropdowns.
            controller.refresh_personalized_model()
            refreshed = controller.personalized_model_list
            return [
                gr.update(choices=refreshed),
                gr.update(choices=["none"] + refreshed),
            ]

        personalized_refresh_button.click(
            fn=_rescan_personalized_models,
            inputs=[],
            outputs=[base_model_dropdown, lora_model_dropdown],
        )

    return base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button
110
+
111
def create_fake_finetune_models_checkpoints(visible):
    """Read-only stand-ins for the Dreambooth/LoRA pickers.

    The base-model dropdown is individually hidden and the LoRA section lives
    in a hidden column so the layout matches the interactive variant.

    Fix: the ``visible`` argument was previously accepted but never used — it
    is now passed through to the row, consistent with
    create_model_checkpoints / create_finetune_models_checkpoints.
    """
    with gr.Row(visible=visible):
        base_model_dropdown = gr.Dropdown(
            label="Select base Dreambooth model (选择基模型[非必需])",
            choices=["none"],
            value="none",
            interactive=False,
            visible=False
        )
    # Deliberately hidden: kept only so the return signature matches the
    # interactive variant.
    with gr.Column(visible=False):
        gr.Markdown(
            """
            ### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora).
            """
        )
        with gr.Row():
            lora_model_dropdown = gr.Dropdown(
                label="Select LoRA model",
                choices=["none"],
                value="none",
                interactive=True,
            )

            lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)

    return base_model_dropdown, lora_model_dropdown, lora_alpha_slider
137
+
138
def create_prompts(
    prompt="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
    negative_prompt="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. "
):
    """Positive/negative prompt textboxes, pre-filled with example prompts."""
    gr.Markdown(
        """
        ### Configs for Generation (生成参数配置).
        """
    )

    positive_box = gr.Textbox(label="Prompt (正向提示词)", lines=2, value=prompt)
    negative_box = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value=negative_prompt)
    return positive_box, negative_box
151
+
152
def create_samplers(controller, maximum_step=100):
    """Sampler dropdown (first registered scheduler is default) plus step slider."""
    with gr.Row():
        scheduler_names = list(controller.scheduler_dict.keys())
        sampler_dropdown = gr.Dropdown(
            label="Sampling method (采样器种类)",
            choices=scheduler_names,
            value=scheduler_names[0],
        )
        sample_step_slider = gr.Slider(
            label="Sampling steps (生成步数)", value=50, minimum=10, maximum=maximum_step, step=1
        )

    return sampler_dropdown, sample_step_slider
158
+
159
def create_height_width(default_height, default_width, maximum_height, maximum_width):
    """Resize-mode radio, W/H sliders, and a hidden base-resolution radio."""
    resize_method = gr.Radio(
        ["Generate by", "Resize according to Reference"],
        value="Generate by",
        show_label=False,
    )

    def _dim_slider(label, default, upper):
        # Both dimension sliders share the 128..upper range with a step of 16.
        return gr.Slider(label=label, value=default, minimum=128, maximum=upper, step=16)

    width_slider = _dim_slider("Width (视频宽度)", default_width, maximum_width)
    height_slider = _dim_slider("Height (视频高度)", default_height, maximum_height)
    base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], visible=False)

    return resize_method, width_slider, height_slider, base_resolution
170
+
171
def create_fake_height_width(default_height, default_width, maximum_height, maximum_width):
    """Non-interactive variant of create_height_width for fixed-host deployments."""
    resize_method = gr.Radio(
        ["Generate by", "Resize according to Reference"],
        value="Generate by",
        show_label=False,
    )

    def _locked_slider(label, default, upper):
        # Same range as the interactive variant, but the user cannot change it.
        return gr.Slider(label=label, value=default, minimum=128, maximum=upper, step=16, interactive=False)

    width_slider = _locked_slider("Width (视频宽度)", default_width, maximum_width)
    height_slider = _locked_slider("Height (视频高度)", default_height, maximum_height)
    base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)

    return resize_method, width_slider, height_slider, base_resolution
182
+
183
def create_generation_methods_and_video_length(
    generation_method_options,
    default_video_length,
    maximum_video_length
):
    """Generation-mode radio plus the three frame-count sliders.

    The overlap/partial sliders start hidden; callers toggle their visibility
    from the generation-method change handler.
    """
    with gr.Group():
        generation_method = gr.Radio(
            generation_method_options,
            value="Video Generation",
            show_label=False,
        )
        with gr.Row():
            length_slider = gr.Slider(
                label="Animation length (视频帧数)",
                value=default_video_length, minimum=1, maximum=maximum_video_length, step=4,
            )
            overlap_video_length = gr.Slider(
                label="Overlap length (视频续写的重叠帧数)",
                value=4, minimum=1, maximum=4, step=1, visible=False,
            )
            partial_video_length = gr.Slider(
                label="Partial video generation length (每个部分的视频生成帧数)",
                value=25, minimum=5, maximum=maximum_video_length, step=4, visible=False,
            )

    return generation_method, length_slider, overlap_video_length, partial_video_length
200
+
201
def create_generation_method(source_method_options, prompt_textbox, support_end_image=True):
    """Build the input-source radio (text/image/video/control) plus the widget
    columns that belong to each source mode. Only one column is shown at a time;
    the caller wires visibility switching via ``source_method.change``.

    Args:
        source_method_options: choices displayed in the source radio.
        prompt_textbox: prompt component; clicking a template image writes its
            canned prompt into this textbox.
        support_end_image: whether the optional end-image accordion is visible.

    Returns:
        Tuple ``(image_to_video_col, video_to_video_col, control_video_col,
        source_method, start_image, template_gallery, end_image,
        validation_video, validation_video_mask, denoise_strength,
        control_video)``.
    """
    source_method = gr.Radio(
        source_method_options,
        value="Text to Video (文本到视频)",
        show_label=False,
    )
    # --- "Image to Video" widgets: start image, template gallery, optional end image.
    with gr.Column(visible = False) as image_to_video_col:
        start_image = gr.Image(
            label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True,
            elem_id="i2v_start", sources="upload", type="filepath",
        )

        template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
        def select_template(evt: gr.SelectData):
            # Map the clicked gallery image to its canned prompt text.
            text = {
                "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.",
                "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.",
                "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.",
                "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.",
                "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.",
            }[template_gallery_path[evt.index]]
            return template_gallery_path[evt.index], text

        template_gallery = gr.Gallery(
            template_gallery_path,
            columns=5, rows=1,
            height=140,
            allow_preview=False,
            container=False,
            label="Template Examples",
        )
        # Selecting a template fills both the start image and the prompt textbox.
        template_gallery.select(select_template, None, [start_image, prompt_textbox])

        with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False, visible=support_end_image):
            end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")

    # --- "Video to Video" widgets: reference video, optional inpaint mask, denoise strength.
    with gr.Column(visible = False) as video_to_video_col:
        with gr.Row():
            validation_video = gr.Video(
                label="The video to convert (视频转视频的参考视频)", show_label=True,
                elem_id="v2v", sources="upload",
            )
        with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
            gr.Markdown(
                """
                - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
                (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
                """
            )
            validation_video_mask = gr.Image(
                label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
                show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
            )
        denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)

    # --- Control-signal widgets (e.g. pose videos) for the control pipelines.
    with gr.Column(visible = False) as control_video_col:
        gr.Markdown(
            """
            Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
            """
        )
        control_video = gr.Video(
            label="The control video (用于提供控制信号的video)", show_label=True,
            elem_id="v2v_control", sources="upload",
        )
    return image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video
267
+
268
def create_cfg_and_seedbox(gradio_version_is_above_4):
    """Build the CFG-scale slider and the seed textbox with a "dice" button
    that fills in a fresh random seed.

    Args:
        gradio_version_is_above_4: when True, use the Gradio 4.x style of
            returning a new component from the callback; otherwise use the
            legacy ``gr.Textbox.update`` API.

    Returns:
        (cfg_scale_slider, seed_textbox, seed_button)
    """
    cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)

    with gr.Row():
        seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
        seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
        seed_button.click(
            # FIX: random.randint requires integer bounds. The previous float
            # literal 1e8 is deprecated since Python 3.10 and raises ValueError
            # on Python >= 3.12, so use the integer 10**8 instead.
            fn=lambda: gr.Textbox(value=random.randint(1, 10**8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 10**8)),
            inputs=[],
            outputs=[seed_textbox]
        )
    return cfg_scale_slider, seed_textbox, seed_button
280
+
281
def create_ui_outputs():
    """Create the shared output area: a generated-image slot (hidden until an
    image result arrives), a video player, and a read-only status textbox.

    Returns:
        (result_image, result_video, infer_progress)
    """
    with gr.Column():
        # Image output stays invisible by default; it is toggled on when the
        # generation produces a single frame instead of a video.
        result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
        result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
        # Read-only progress/status line updated by the generate callback.
        infer_progress = gr.Textbox(label="Generation Info (生成信息)", value="No task currently", interactive=False)
    return result_image, result_video, infer_progress
videox_fun/ui/wan_fun_ui.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
+ """
3
+ import os
4
+ import random
5
+
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ from omegaconf import OmegaConf
11
+ from PIL import Image
12
+ from safetensors import safe_open
13
+
14
+ from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
15
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
16
+ WanT5EncoderModel, WanTransformer3DModel)
17
+ from ..models.cache_utils import get_teacache_coefficients
18
+ from ..pipeline import WanFunInpaintPipeline, WanFunPipeline, WanFunControlPipeline
19
+ from ..utils.fp8_optimization import (convert_model_weight_to_float8,
20
+ convert_weight_dtype_wrapper,
21
+ replace_parameters_by_name)
22
+ from ..utils.lora_utils import merge_lora, unmerge_lora
23
+ from ..utils.utils import (filter_kwargs, get_image_to_video_latent,
24
+ get_video_to_video_latent, save_videos_grid)
25
+ from .controller import (Fun_Controller, Fun_Controller_Client,
26
+ all_cheduler_dict, css, ddpm_scheduler_dict,
27
+ flow_scheduler_dict, gradio_version,
28
+ gradio_version_is_above_4)
29
+ from .ui import (create_cfg_and_seedbox,
30
+ create_fake_finetune_models_checkpoints,
31
+ create_fake_height_width, create_fake_model_checkpoints,
32
+ create_fake_model_type, create_finetune_models_checkpoints,
33
+ create_generation_method,
34
+ create_generation_methods_and_video_length,
35
+ create_height_width, create_model_checkpoints,
36
+ create_model_type, create_prompts, create_samplers,
37
+ create_ui_outputs)
38
+
39
+
40
class Wan_Fun_Controller(Fun_Controller):
    """Gradio-facing controller for Wan-Fun: loads the VAE/transformer/text
    encoder stack into one of the Wan pipelines and runs generation requests
    coming from the UI or the API."""

    def update_diffusion_transformer(self, diffusion_transformer_dropdown):
        """(Re)load every sub-model from ``diffusion_transformer_dropdown`` and
        assemble the matching pipeline (inpaint / t2v / control), then apply
        the configured GPU-memory offloading strategy.

        Returns:
            ``gr.update()`` so the dropdown callback is a no-op UI-wise.
        """
        print("Update diffusion transformer")
        self.diffusion_transformer_dropdown = diffusion_transformer_dropdown
        if diffusion_transformer_dropdown == "none":
            return gr.update()
        self.vae = AutoencoderKLWan.from_pretrained(
            os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')),
            additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']),
        ).to(self.weight_dtype)

        # Get Transformer
        self.transformer = WanTransformer3DModel.from_pretrained(
            os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
            transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']),
            low_cpu_mem_usage=True,
            torch_dtype=self.weight_dtype,
        )

        # Get Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
        )

        # Get Text encoder
        self.text_encoder = WanT5EncoderModel.from_pretrained(
            os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
            additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']),
            low_cpu_mem_usage=True,
            torch_dtype=self.weight_dtype,
        )
        self.text_encoder = self.text_encoder.eval()

        # A channel mismatch between the transformer input and the VAE latents
        # indicates an image-conditioned (I2V) checkpoint, which needs CLIP.
        if self.transformer.config.in_channels != self.vae.config.latent_channels:
            # Get Clip Image Encoder
            self.clip_image_encoder = CLIPModel.from_pretrained(
                os.path.join(diffusion_transformer_dropdown, self.config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
            ).to(self.weight_dtype)
            self.clip_image_encoder = self.clip_image_encoder.eval()
        else:
            self.clip_image_encoder = None

        # Instantiate the first scheduler in the dict as the default.
        Choosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]]
        self.scheduler = Choosen_Scheduler(
            **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs']))
        )

        # Get pipeline
        if self.model_type == "Inpaint":
            if self.transformer.config.in_channels != self.vae.config.latent_channels:
                self.pipeline = WanFunInpaintPipeline(
                    vae=self.vae,
                    tokenizer=self.tokenizer,
                    text_encoder=self.text_encoder,
                    transformer=self.transformer,
                    scheduler=self.scheduler,
                    clip_image_encoder=self.clip_image_encoder,
                )
            else:
                self.pipeline = WanFunPipeline(
                    vae=self.vae,
                    tokenizer=self.tokenizer,
                    text_encoder=self.text_encoder,
                    transformer=self.transformer,
                    scheduler=self.scheduler,
                )
        else:
            self.pipeline = WanFunControlPipeline(
                vae=self.vae,
                tokenizer=self.tokenizer,
                text_encoder=self.text_encoder,
                transformer=self.transformer,
                scheduler=self.scheduler,
                clip_image_encoder=self.clip_image_encoder,
            )

        if self.ulysses_degree > 1 or self.ring_degree > 1:
            self.transformer.enable_multi_gpus_inference()

        # GPU memory strategy: sequential offload is slowest/leanest, plain
        # .to(device) keeps everything resident.
        if self.GPU_memory_mode == "sequential_cpu_offload":
            # "modulation" parameters must stay on-device for sequential offload.
            replace_parameters_by_name(self.transformer, ["modulation",], device=self.device)
            self.transformer.freqs = self.transformer.freqs.to(device=self.device)
            self.pipeline.enable_sequential_cpu_offload(device=self.device)
        elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
            convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",])
            convert_weight_dtype_wrapper(self.transformer, self.weight_dtype)
            self.pipeline.enable_model_cpu_offload(device=self.device)
        elif self.GPU_memory_mode == "model_cpu_offload":
            self.pipeline.enable_model_cpu_offload(device=self.device)
        else:
            self.pipeline.to(self.device)
        print("Update diffusion transformer done")
        return gr.update()

    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        overlap_video_length,
        partial_video_length,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        control_video,
        denoise_strength,
        seed_textbox,
        is_api = False,
    ):
        """Run one generation request and return UI (or API) results.

        Returns:
            API mode: ``(save_path, message)``.
            UI mode: ``(image_update, video_update, message)``.

        NOTE(review): denoise_strength is accepted for UI parity but is not
        forwarded to any Wan pipeline call below — confirm whether v2v denoise
        is intended to be supported here.
        """
        self.clear_cache()

        self.input_check(
            resize_method, generation_method, start_image, end_image, validation_video, control_video, is_api
        )
        is_image = generation_method == "Image Generation"

        if self.base_model_path != base_model_dropdown:
            self.update_base_model(base_model_dropdown)

        if self.lora_model_path != lora_model_dropdown:
            self.update_lora_model(lora_model_dropdown)

        self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)

        if resize_method == "Resize according to Reference":
            height_slider, width_slider = self.get_height_width_from_reference(
                base_resolution, start_image, validation_video, control_video,
            )
        if self.lora_model_path != "none":
            # lora part
            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

        coefficients = get_teacache_coefficients(self.base_model_path) if self.enable_teacache else None
        if coefficients is not None:
            print(f"Enable TeaCache with threshold {self.teacache_threshold} and skip the first {self.num_skip_start_steps} steps.")
            self.pipeline.transformer.enable_teacache(
                coefficients, sample_step_slider, self.teacache_threshold, num_skip_start_steps=self.num_skip_start_steps, offload=self.teacache_offload
            )

        # FIX: test for an empty textbox *before* calling int() on it —
        # int("") raises ValueError, so the previous order crashed whenever the
        # seed box was cleared. Empty or -1 means "pick a random seed".
        if seed_textbox != "" and int(seed_textbox) != -1:
            torch.manual_seed(int(seed_textbox))
        else:
            # FIX: use an integer bound with explicit int64 so the upper bound
            # stays valid where numpy's default integer is 32-bit (e.g. Windows).
            seed_textbox = np.random.randint(0, 10**10, dtype=np.int64)
        generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))

        if self.enable_riflex:
            latent_frames = (int(length_slider) - 1) // self.vae.config.temporal_compression_ratio + 1
            self.pipeline.transformer.enable_riflex(k = self.riflex_k, L_test = latent_frames if not is_image else 1)

        try:
            if self.model_type == "Inpaint":
                if self.transformer.config.in_channels != self.vae.config.latent_channels:
                    # I2V / V2V checkpoint: build video + mask latents first.
                    if validation_video is not None:
                        input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=16)
                    else:
                        input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))

                    sample = self.pipeline(
                        prompt_textbox,
                        negative_prompt = negative_prompt_textbox,
                        num_inference_steps = sample_step_slider,
                        guidance_scale = cfg_scale_slider,
                        width = width_slider,
                        height = height_slider,
                        num_frames = length_slider if not is_image else 1,
                        generator = generator,

                        video = input_video,
                        mask_video = input_video_mask,
                        clip_image = clip_image
                    ).videos
                else:
                    # Pure text-to-video checkpoint.
                    sample = self.pipeline(
                        prompt_textbox,
                        negative_prompt = negative_prompt_textbox,
                        num_inference_steps = sample_step_slider,
                        guidance_scale = cfg_scale_slider,
                        width = width_slider,
                        height = height_slider,
                        num_frames = length_slider if not is_image else 1,
                        generator = generator
                    ).videos
            else:
                # Control pipeline: the uploaded control video drives generation.
                input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=16)

                sample = self.pipeline(
                    prompt_textbox,
                    negative_prompt = negative_prompt_textbox,
                    num_inference_steps = sample_step_slider,
                    guidance_scale = cfg_scale_slider,
                    width = width_slider,
                    height = height_slider,
                    num_frames = length_slider if not is_image else 1,
                    generator = generator,

                    control_video = input_video,
                ).videos
        except Exception as e:
            # Best-effort cleanup so a failed run leaves the pipeline reusable.
            self.clear_cache()
            if self.lora_model_path != "none":
                self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
            if is_api:
                return "", f"Error. error information is {str(e)}"
            else:
                return gr.update(), gr.update(), f"Error. error information is {str(e)}"

        self.clear_cache()
        # lora part
        if self.lora_model_path != "none":
            self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

        save_sample_path = self.save_outputs(
            is_image, length_slider, sample, fps=16
        )

        if is_image or length_slider == 1:
            if is_api:
                return save_sample_path, "Success"
            else:
                if gradio_version_is_above_4:
                    return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
                else:
                    return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
        else:
            if is_api:
                return save_sample_path, "Success"
            else:
                if gradio_version_is_above_4:
                    return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
                else:
                    return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"

# Host serves the full controller; the client talks to a remote host.
Wan_Fun_Controller_Host = Wan_Fun_Controller
Wan_Fun_Controller_Client = Fun_Controller_Client
285
+
286
def ui(GPU_memory_mode, scheduler_dict, config_path, ulysses_degree, ring_degree, enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload, enable_riflex, riflex_k, weight_dtype, savedir_sample=None):
    """Build the full local Gradio UI for Wan-Fun. The controller owns the
    models; the returned Blocks wires every widget to its callbacks.

    Returns:
        (demo, controller): the gr.Blocks app and its Wan_Fun_Controller.
    """
    controller = Wan_Fun_Controller(
        GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
        config_path=config_path, ulysses_degree=ulysses_degree, ring_degree=ring_degree,
        enable_teacache=enable_teacache, teacache_threshold=teacache_threshold,
        num_skip_start_steps=num_skip_start_steps, teacache_offload=teacache_offload,
        enable_riflex=enable_riflex, riflex_k=riflex_k, weight_dtype=weight_dtype,
        savedir_sample=savedir_sample,
    )

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # Wan-Fun:

            A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 81), as well as image generated videos.

            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
            """
        )
        # Model-selection panel: pipeline type, base checkpoint, LoRA.
        with gr.Column(variant="panel"):
            model_type = create_model_type(visible=True)
            diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \
                create_model_checkpoints(controller, visible=True)
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \
                create_finetune_models_checkpoints(controller, visible=True)

        # Generation panel: prompts, sampler, resolution, inputs, outputs.
        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)

                    resize_method, width_slider, height_slider, base_resolution = create_height_width(
                        default_height = 480, default_width = 832, maximum_height = 1344,
                        maximum_width = 1344,
                    )
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation", "Image Generation"],
                            default_video_length=81,
                            maximum_video_length=81,
                        )
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox
                    )
                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

            # Switching the model type swaps the pipeline class in the controller.
            model_type.change(
                fn=controller.update_model_type,
                inputs=[model_type],
                outputs=[]
            )

            def upload_generation_method(generation_method):
                # Adjust the length slider and long-video controls per mode.
                if generation_method == "Video Generation":
                    return [gr.update(visible=True, maximum=81, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)]
                elif generation_method == "Image Generation":
                    return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)]
                else:
                    # NOTE(review): unreachable with the two options above, and
                    # maximum=1344 looks copied from a width slider — confirm.
                    return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
            generation_method.change(
                upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
            )

            def upload_source_method(source_method):
                # Show only the column matching the chosen source; clear the rest.
                if source_method == "Text to Video (文本到视频)":
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Image to Video (图片到视频)":
                    return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Video to Video (视频到视频)":
                    return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
            source_method.change(
                upload_source_method, source_method, [
                    image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
                    validation_video, validation_video_mask, control_video
                ]
            )

            def upload_resize_method(resize_method):
                # Toggle explicit width/height vs. reference-based sizing.
                if resize_method == "Generate by":
                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
            resize_method.change(
                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
            )

            # Input order must match Wan_Fun_Controller.generate's signature.
            generate_button.click(
                fn=controller.generate,
                inputs=[
                    diffusion_transformer_dropdown,
                    base_model_dropdown,
                    lora_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    resize_method,
                    width_slider,
                    height_slider,
                    base_resolution,
                    generation_method,
                    length_slider,
                    overlap_video_length,
                    partial_video_length,
                    cfg_scale_slider,
                    start_image,
                    end_image,
                    validation_video,
                    validation_video_mask,
                    control_video,
                    denoise_strength,
                    seed_textbox,
                ],
                outputs=[result_image, result_video, infer_progress]
            )
    return demo, controller
412
+
413
def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, config_path, ulysses_degree, ring_degree, enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload, enable_riflex, riflex_k, weight_dtype, savedir_sample=None):
    """Build the host-mode UI: a single pre-chosen model (``model_name``),
    with the model/LoRA selectors rendered as non-interactive placeholders.

    Returns:
        (demo, controller): the gr.Blocks app and its Wan_Fun_Controller_Host.
    """
    controller = Wan_Fun_Controller_Host(
        GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type,
        config_path=config_path, ulysses_degree=ulysses_degree, ring_degree=ring_degree,
        enable_teacache=enable_teacache, teacache_threshold=teacache_threshold,
        num_skip_start_steps=num_skip_start_steps, teacache_offload=teacache_offload,
        enable_riflex=enable_riflex, riflex_k=riflex_k, weight_dtype=weight_dtype,
        savedir_sample=savedir_sample,
    )

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # Wan-Fun:

            A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 81), as well as image generated videos.

            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
            """
        )
        # Fixed model panel: "fake" components show the preloaded checkpoint.
        with gr.Column(variant="panel"):
            model_type = create_fake_model_type(visible=True)
            diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)

        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)

                    resize_method, width_slider, height_slider, base_resolution = create_height_width(
                        default_height = 480, default_width = 832, maximum_height = 1344,
                        maximum_width = 1344,
                    )
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation", "Image Generation"],
                            default_video_length=81,
                            maximum_video_length=81,
                        )
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox
                    )
                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

            def upload_generation_method(generation_method):
                # NOTE(review): no else branch — any value other than the two
                # handled ones returns None implicitly; confirm that is intended.
                if generation_method == "Video Generation":
                    return gr.update(visible=True, minimum=1, maximum=81, value=81, interactive=True)
                elif generation_method == "Image Generation":
                    return gr.update(minimum=1, maximum=1, value=1, interactive=False)
            generation_method.change(
                upload_generation_method, generation_method, [length_slider]
            )

            def upload_source_method(source_method):
                # Show only the column matching the chosen source; clear the rest.
                if source_method == "Text to Video (文本到视频)":
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Image to Video (图片到视频)":
                    return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Video to Video (视频到视频)":
                    return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
            source_method.change(
                upload_source_method, source_method, [
                    image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
                    validation_video, validation_video_mask, control_video
                ]
            )

            def upload_resize_method(resize_method):
                # Toggle explicit width/height vs. reference-based sizing.
                if resize_method == "Generate by":
                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
            resize_method.change(
                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
            )

            # Input order must match Wan_Fun_Controller.generate's signature.
            generate_button.click(
                fn=controller.generate,
                inputs=[
                    diffusion_transformer_dropdown,
                    base_model_dropdown,
                    lora_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    resize_method,
                    width_slider,
                    height_slider,
                    base_resolution,
                    generation_method,
                    length_slider,
                    overlap_video_length,
                    partial_video_length,
                    cfg_scale_slider,
                    start_image,
                    end_image,
                    validation_video,
                    validation_video_mask,
                    control_video,
                    denoise_strength,
                    seed_textbox,
                ],
                outputs=[result_image, result_video, infer_progress]
            )
    return demo, controller
529
+
530
def ui_client(scheduler_dict, model_name, savedir_sample=None):
    """Build the client-mode UI: no local models; the controller forwards
    generation requests to a remote host. Note the control-video widgets are
    created but not wired into the generate call in this mode.

    Returns:
        (demo, controller): the gr.Blocks app and its Wan_Fun_Controller_Client.
    """
    controller = Wan_Fun_Controller_Client(scheduler_dict, savedir_sample)

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # Wan-Fun:

            A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 81), as well as image generated videos.

            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
            """
        )
        # Fixed model panel: placeholders showing the remote checkpoint.
        with gr.Column(variant="panel"):
            diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)

        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50)

                    resize_method, width_slider, height_slider, base_resolution = create_fake_height_width(
                        default_height = 480, default_width = 832, maximum_height = 1344,
                        maximum_width = 1344,
                    )
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation", "Image Generation"],
                            default_video_length=81,
                            maximum_video_length=81,
                        )
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox
                    )

                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

            def upload_generation_method(generation_method):
                # NOTE(review): no else branch — unhandled values return None
                # implicitly; confirm that is intended.
                if generation_method == "Video Generation":
                    return gr.update(visible=True, minimum=5, maximum=81, value=49, interactive=True)
                elif generation_method == "Image Generation":
                    return gr.update(minimum=1, maximum=1, value=1, interactive=False)
            generation_method.change(
                upload_generation_method, generation_method, [length_slider]
            )

            def upload_source_method(source_method):
                # Client mode toggles only the i2v and v2v columns (no control col).
                if source_method == "Text to Video (文本到视频)":
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Image to Video (图片到视频)":
                    return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
                else:
                    return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
            source_method.change(
                upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
            )

            def upload_resize_method(resize_method):
                # Toggle explicit width/height vs. reference-based sizing.
                if resize_method == "Generate by":
                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
            resize_method.change(
                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
            )

            # Client generate signature omits overlap/partial lengths and the
            # control video relative to the host controller.
            generate_button.click(
                fn=controller.generate,
                inputs=[
                    diffusion_transformer_dropdown,
                    base_model_dropdown,
                    lora_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    resize_method,
                    width_slider,
                    height_slider,
                    base_resolution,
                    generation_method,
                    length_slider,
                    cfg_scale_slider,
                    start_image,
                    end_image,
                    validation_video,
                    validation_video_mask,
                    denoise_strength,
                    seed_textbox,
                ],
                outputs=[result_image, result_video, infer_progress]
            )
    return demo, controller
videox_fun/utils/__init__.py ADDED
File without changes
videox_fun/utils/discrete_sampler.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modified from https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py
2
+ """
3
+ import torch
4
+
5
class DiscreteSampling:
    """Draws integer timestep indices in ``[0, num_idx)``.

    When ``torch.distributed`` is initialized and ``uniform_sampling`` is
    enabled, the ranks are partitioned into groups and each group samples
    from its own contiguous slice of the index range, so the whole world
    covers the range evenly.
    """

    def __init__(self, num_idx, uniform_sampling=False):
        self.num_idx = num_idx
        self.uniform_sampling = uniform_sampling
        self.is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()

        if self.is_distributed and self.uniform_sampling:
            world_size = torch.distributed.get_world_size()
            self.rank = torch.distributed.get_rank()

            # Find the smallest i such that world_size is divisible by i and
            # the index range splits evenly across world_size // i groups.
            i = 1
            while world_size % i != 0 or num_idx % (world_size // i) != 0:
                i += 1
            self.group_num = world_size // i

            assert self.group_num > 0
            assert world_size % self.group_num == 0
            # Number of ranks sharing one group (and thus one index slice).
            self.group_width = world_size // self.group_num
            self.sigma_interval = self.num_idx // self.group_num
            print('rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s' % (
                self.rank, world_size, self.group_num,
                self.group_width, self.sigma_interval))

    def __call__(self, n_samples, generator=None, device=None):
        """Return a 1-D LongTensor of ``n_samples`` sampled indices."""
        if self.is_distributed and self.uniform_sampling:
            group_index = self.rank // self.group_width
            low = group_index * self.sigma_interval
            high = (group_index + 1) * self.sigma_interval
            idx = torch.randint(
                low, high, (n_samples,),
                generator=generator, device=device,
            )
            print('proc[%d] idx=%s' % (self.rank, idx))
        else:
            idx = torch.randint(
                0, self.num_idx, (n_samples,),
                generator=generator, device=device,
            )
        return idx
videox_fun/utils/fp8_optimization.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modified from https://github.com/kijai/ComfyUI-MochiWrapper
2
+ """
3
+ import torch
4
+ import torch.nn as nn
5
+
6
def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
    """Run ``cls.original_forward`` with the module and its positional tensor
    inputs cast to ``origin_dtype``, then restore the module's stored dtype.

    Keyword arguments are forwarded unchanged.
    """
    stored_dtype = cls.weight.dtype
    cls.to(origin_dtype)

    # Cast every positional argument to the compute dtype before the call.
    casted = [tensor.to(origin_dtype) for tensor in inputs]
    result = cls.original_forward(*casted, **kwargs)

    cls.to(stored_dtype)
    return result
16
+
17
def replace_parameters_by_name(module, name_keywords, device):
    """Recursively turn matching ``nn.Parameter`` attributes into plain tensors.

    Any direct parameter of ``module`` (or of any descendant) whose attribute
    name contains one of ``name_keywords`` is deregistered and replaced by its
    raw data tensor moved to ``device``.
    """
    for param_name, param in list(module.named_parameters(recurse=False)):
        if not any(keyword in param_name for keyword in name_keywords):
            continue
        if isinstance(param, nn.Parameter):
            raw = param.data
            # delattr removes the entry from the module's parameter registry;
            # setattr then stores a plain (non-parameter) tensor attribute.
            delattr(module, param_name)
            setattr(module, param_name, raw.to(device=device))
    for _, child in module.named_children():
        replace_parameters_by_name(child, name_keywords, device)
27
+
28
+ def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens']):
29
+ for name, module in model.named_modules():
30
+ flag = False
31
+ for _exclude_module_name in exclude_module_name:
32
+ if _exclude_module_name in name:
33
+ flag = True
34
+ if flag:
35
+ continue
36
+ for param_name, param in module.named_parameters():
37
+ flag = False
38
+ for _exclude_module_name in exclude_module_name:
39
+ if _exclude_module_name in param_name:
40
+ flag = True
41
+ if flag:
42
+ continue
43
+ param.data = param.data.to(torch.float8_e4m3fn)
44
+
45
def convert_weight_dtype_wrapper(module, origin_dtype):
    """Patch the ``forward`` of every weighted submodule to compute in
    ``origin_dtype`` via ``autocast_model_forward``.

    The root module (empty name) and any submodule whose name contains
    "embed_tokens" are skipped. Each wrapped submodule keeps its original
    forward as ``original_forward``.
    """
    for sub_name, sub_module in module.named_modules():
        if sub_name == "" or "embed_tokens" in sub_name:
            continue
        saved_forward = sub_module.forward
        if hasattr(sub_module, "weight") and sub_module.weight is not None:
            setattr(sub_module, "original_forward", saved_forward)
            # Bind the submodule via a default argument so each lambda keeps
            # its own module instead of closing over the loop variable.
            setattr(
                sub_module,
                "forward",
                lambda *inputs, m=sub_module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs)
            )
videox_fun/utils/lora_utils.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+ # https://github.com/bmaltais/kohya_ss
6
+
7
+ import hashlib
8
+ import math
9
+ import os
10
+ from collections import defaultdict
11
+ from io import BytesIO
12
+ from typing import List, Optional, Type, Union
13
+
14
+ import safetensors.torch
15
+ import torch
16
+ import torch.utils.checkpoint
17
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
18
+ from safetensors.torch import load_file
19
+ from transformers import T5EncoderModel
20
+
21
+
22
class LoRAModule(torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.

    Adds a low-rank residual ``lora_up(lora_down(x)) * multiplier * scale`` on
    top of the wrapped layer's output. Supports ``Linear`` and ``Conv2d``
    layers and three dropout flavors: plain dropout on the down-projection,
    rank dropout (zeroing random rank channels), and module dropout
    (occasionally disabling the whole adapter during training).
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name

        # Infer adapter I/O sizes from the wrapped layer's type.
        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim
        if org_module.__class__.__name__ == "Conv2d":
            # Conv adapters: the down-projection mirrors the original kernel
            # geometry; the up-projection is a 1x1 conv back to out channels.
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        # Fix: use isinstance rather than `type(...) ==` so tensor subclasses
        # are detected too.
        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        # same as microsoft's: kaiming-init the down projection, zero-init the
        # up projection so the adapter starts as an identity residual.
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        """Swap the wrapped module's ``forward`` for this adapter's forward and
        drop the back-reference so the original is no longer a submodule."""
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x, *args, **kwargs):
        """Original output plus the (optionally dropped-out) LoRA residual,
        returned in the input's dtype. Must be called after ``apply_to``."""
        weight_dtype = x.dtype
        org_forwarded = self.org_forward(x)

        # module dropout: occasionally bypass the adapter entirely.
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        lx = self.lora_down(x.to(self.lora_down.weight.dtype))

        # normal dropout
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        # rank dropout: zero random rank channels per sample.
        if self.rank_dropout is not None and self.training:
            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
            if len(lx.size()) == 3:
                mask = mask.unsqueeze(1)  # for Text Encoder
            elif len(lx.size()) == 4:
                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
            lx = lx * mask

            # scaling for rank dropout: treat as if the rank is changed
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
        else:
            scale = self.scale

        lx = self.lora_up(lx)

        return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
113
+
114
+
115
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
    digest = hashlib.sha256()

    # Legacy webui hash: digest a 64 KiB window starting at the 1 MiB offset,
    # truncated to 8 hex characters.
    b.seek(0x100000)
    digest.update(b.read(0x10000))
    return digest.hexdigest()[0:8]
122
+
123
+
124
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
    digest = hashlib.sha256()
    chunk_size = 1024 * 1024

    # The first 8 bytes are the little-endian length of the JSON header; the
    # hash covers everything after that header.
    b.seek(0)
    header_len = int.from_bytes(b.read(8), "little")
    b.seek(header_len + 8)

    while True:
        chunk = b.read(chunk_size)
        if not chunk:
            break
        digest.update(chunk)

    return digest.hexdigest()
139
+
140
+
141
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    Returns:
        (model_hash, legacy_hash): the new-style and legacy webui hashes of
        the serialized safetensors payload.
    """

    # Because writing user metadata to the file can change the result of
    # sd_models.model_hash(), only retain the training metadata for purposes of
    # calculating the hash, as they are meant to be immutable
    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    # Fix: renamed from `bytes`, which shadowed the builtin type.
    serialized = safetensors.torch.save(tensors, metadata)
    b = BytesIO(serialized)

    model_hash = addnet_hash_safetensors(b)
    legacy_hash = addnet_hash_legacy(b)
    return model_hash, legacy_hash
156
+
157
+
158
class LoRANetwork(torch.nn.Module):
    """Builds and manages LoRAModule adapters for a diffusion transformer and
    (optionally) its text encoder(s).

    Adapters are created for every Linear / 1x1-Conv child found under modules
    whose class names appear in the *_TARGET_REPLACE_MODULE lists, and are
    activated with `apply_to` (which monkey-patches the wrapped forwards).
    """

    # Class names whose Linear/Conv children receive LoRA adapters.
    TRANSFORMER_TARGET_REPLACE_MODULE = ["CogVideoXTransformer3DModel", "WanTransformer3DModel"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder", "T5SelfAttention", "T5CrossAttention"]
    # Prefixes used when composing LoRA weight-key names ("lora_unet_...", "lora_te_...").
    LORA_PREFIX_TRANSFORMER = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    def __init__(
        self,
        text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
        unet,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        module_class: Type[object] = LoRAModule,
        skip_name: str = None,
        varbose: Optional[bool] = False,  # NOTE(review): typo for "verbose"; unused in this body, kept for caller compatibility
    ) -> None:
        """Create (but do not yet apply) LoRA modules for `unet` and each text encoder.

        Args:
            text_encoder: a single text encoder or a list of them; None entries are skipped.
            unet: the transformer/U-Net whose target modules get adapters.
            multiplier: global strength multiplier passed to each LoRAModule.
            lora_dim: rank of each adapter.
            alpha: LoRA alpha (scale = alpha / lora_dim inside LoRAModule).
            dropout: plain-dropout probability forwarded to each adapter.
            module_class: adapter class to instantiate (defaults to LoRAModule).
            skip_name: substring; child modules whose name contains it get no adapter.
        """
        super().__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.dropout = dropout

        print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
        print(f"neuron dropout: p={self.dropout}")

        # create module instances
        def create_modules(
            is_unet: bool,
            root_module: torch.nn.Module,
            target_replace_modules: List[torch.nn.Module],
        ) -> List[LoRAModule]:
            # Walks root_module; for every module whose class name is targeted,
            # wraps each Linear / 1x1-Conv child in a LoRA adapter. Returns the
            # adapters plus the names skipped because no rank was assigned.
            prefix = (
                self.LORA_PREFIX_TRANSFORMER
                if is_unet
                else self.LORA_PREFIX_TEXT_ENCODER
            )
            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if skip_name is not None and skip_name in child_name:
                            continue

                        if is_linear or is_conv2d:
                            # LoRA weight-key name: prefix + dotted path, dots replaced by underscores.
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")

                            dim = None
                            alpha = None

                            # Only Linear and 1x1 Conv get a rank; larger convs fall through with dim=None.
                            if is_linear or is_conv2d_1x1:
                                dim = self.lora_dim
                                alpha = self.alpha

                            if dim is None or dim == 0:
                                if is_linear or is_conv2d_1x1:
                                    skipped.append(lora_name)
                                continue

                            lora = module_class(
                                lora_name,
                                child_module,
                                self.multiplier,
                                dim,
                                alpha,
                                dropout=dropout,  # closes over the __init__ parameter
                            )
                            loras.append(lora)
            return loras, skipped

        # Normalize to a list so single and multi text-encoder setups share one path.
        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]

        self.text_encoder_loras = []
        skipped_te = []
        for i, text_encoder in enumerate(text_encoders):
            if text_encoder is not None:
                text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
                self.text_encoder_loras.extend(text_encoder_loras)
                skipped_te += skipped
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        # assertion: every adapter must have a unique weight-key name.
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        """Activate the adapters: patch each wrapped forward and register the
        adapters as submodules. Disabled groups are dropped entirely."""
        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

    def set_multiplier(self, multiplier):
        """Update the strength multiplier on the network and every adapter."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        """Load adapter weights from a .safetensors or torch checkpoint file.

        Returns the (missing, unexpected) keys info from load_state_dict
        (strict=False)."""
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")
        info = self.load_state_dict(weights_sd, False)
        return info

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        """Return optimizer param groups: one for text-encoder adapters, one
        for unet adapters, each with its own lr when provided.

        NOTE(review): `default_lr` is accepted but never read in this body.
        """
        self.requires_grad_(True)
        all_params = []

        def enumerate_params(loras):
            # Flatten the parameters of a list of adapters.
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            return params

        if self.text_encoder_loras:
            param_data = {"params": enumerate_params(self.text_encoder_loras)}
            if text_encoder_lr is not None:
                param_data["lr"] = text_encoder_lr
            all_params.append(param_data)

        if self.unet_loras:
            param_data = {"params": enumerate_params(self.unet_loras)}
            if unet_lr is not None:
                param_data["lr"] = unet_lr
            all_params.append(param_data)

        return all_params

    def enable_gradient_checkpointing(self):
        # Intentionally a no-op: adapters are shallow, nothing to checkpoint.
        pass

    def get_trainable_params(self):
        """Return all trainable parameters of the registered adapters."""
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """Save adapter weights to `file` (.safetensors or torch format),
        optionally casting to `dtype` and embedding webui-compatible hashes
        in the safetensors metadata."""
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)
340
+
341
def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
    transformer,
    neuron_dropout: Optional[float] = None,
    skip_name: str = None,
    **kwargs,
):
    """Factory for a LoRANetwork over `transformer` and `text_encoder`.

    `network_dim` falls back to 4 and `network_alpha` to 1.0 when omitted.
    Extra keyword arguments are accepted (and ignored) for config compatibility.
    """
    # Conventional rank/alpha defaults.
    dim = 4 if network_dim is None else network_dim
    alpha = 1.0 if network_alpha is None else network_alpha

    return LoRANetwork(
        text_encoder,
        transformer,
        multiplier=multiplier,
        lora_dim=dim,
        alpha=alpha,
        dropout=neuron_dropout,
        skip_name=skip_name,
        varbose=True,
    )
367
+
368
def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
    """Merge LoRA weights directly into a diffusers pipeline's layer weights.

    Args:
        pipeline: pipeline exposing `.transformer` and `.text_encoder` modules.
        lora_path: path to a .safetensors LoRA file (ignored when `state_dict` is given).
        multiplier: strength applied to the merged delta.
        device, dtype: where/how the matrix products are computed.
        state_dict: optional pre-loaded LoRA state dict.
        transformer_only: when True, text-encoder ("lora_te") layers are skipped.

    Returns:
        The same pipeline, with `W += multiplier * alpha * up @ down` applied per layer.
    """
    LORA_PREFIX_TRANSFORMER = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    if state_dict is None:
        state_dict = load_file(lora_path, device=device)
    else:
        state_dict = state_dict
    # Group flat keys "layer.elem" into {layer: {elem: tensor}}.
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    # A meta-device transformer indicates sequential CPU offload is active;
    # hooks must be removed before touching weights and re-enabled afterwards.
    sequential_cpu_offload_flag = False
    if pipeline.transformer.device == torch.device(type="meta"):
        pipeline.remove_all_hooks()
        sequential_cpu_offload_flag = True
        offload_device = pipeline._offload_device

    for layer, elems in updates.items():

        if "lora_te" in layer:
            if transformer_only:
                continue
            else:
                layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
                curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
            curr_layer = pipeline.transformer

        # Resolve the target submodule. Underscores in the flattened LoRA key
        # are ambiguous (they stand for both "." separators and underscores in
        # attribute names), so resolution is trial-and-error: greedily try the
        # longest joined name, fall back to consuming one segment at a time.
        try:
            curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
        except Exception:
            temp_name = layer_infos.pop(0)
            while len(layer_infos) > -1:  # effectively `while True`; exits via break
                try:
                    curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos))
                    break
                except Exception:
                    try:
                        curr_layer = curr_layer.__getattr__(temp_name)
                        if len(layer_infos) > 0:
                            temp_name = layer_infos.pop(0)
                        elif len(layer_infos) == 0:
                            break
                    except Exception:
                        if len(layer_infos) == 0:
                            print('Error loading layer')
                        if len(temp_name) > 0:
                            temp_name += "_" + layer_infos.pop(0)
                        else:
                            temp_name = layer_infos.pop(0)

        # Remember where/what the layer was so it can be restored after the merge.
        origin_dtype = curr_layer.weight.data.dtype
        origin_device = curr_layer.weight.data.device

        curr_layer = curr_layer.to(device, dtype)
        weight_up = elems['lora_up.weight'].to(device, dtype)
        weight_down = elems['lora_down.weight'].to(device, dtype)

        # alpha is stored as the raw LoRA alpha; dividing by the rank
        # (weight_up.shape[1]) yields the effective scale.
        if 'alpha' in elems.keys():
            alpha = elems['alpha'].item() / weight_up.shape[1]
        else:
            alpha = 1.0

        # 4-D weights are 1x1 convs: squeeze to matrices, multiply, unsqueeze back.
        if len(weight_up.shape) == 4:
            curr_layer.weight.data += multiplier * alpha * torch.mm(
                weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
            ).unsqueeze(2).unsqueeze(3)
        else:
            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
        curr_layer = curr_layer.to(origin_device, origin_dtype)

    if sequential_cpu_offload_flag:
        pipeline.enable_sequential_cpu_offload(device=offload_device)
    return pipeline
444
+
445
+ # TODO: Refactor with merge_lora.
446
def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
    """Unmerge state_dict in LoRANetwork from the pipeline in diffusers.

    Exact inverse of `merge_lora`: subtracts `multiplier * alpha * up @ down`
    from each previously merged layer weight.

    Args:
        pipeline: pipeline exposing `.transformer` and `.text_encoder` modules.
        lora_path: path to the .safetensors LoRA file that was merged.
        multiplier: strength that was used when merging.
        device, dtype: where/how the matrix products are computed.

    Returns:
        The same pipeline with the LoRA delta removed.
    """
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    state_dict = load_file(lora_path, device=device)

    # Group flat keys "layer.elem" into {layer: {elem: tensor}}.
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    # A meta-device transformer indicates sequential CPU offload is active;
    # hooks must be removed before touching weights and re-enabled afterwards.
    sequential_cpu_offload_flag = False
    if pipeline.transformer.device == torch.device(type="meta"):
        pipeline.remove_all_hooks()
        sequential_cpu_offload_flag = True
        # Fix: capture the original offload device (as merge_lora does) so
        # offload is restored to where it was, not to the weight-math `device`
        # argument (previously this re-offloaded to `device`, i.e. "cpu").
        offload_device = pipeline._offload_device

    for layer, elems in updates.items():

        if "lora_te" in layer:
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.transformer

        # Resolve the target submodule by trial-and-error over the ambiguous
        # underscore-flattened name (same scheme as merge_lora).
        try:
            curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
        except Exception:
            temp_name = layer_infos.pop(0)
            while len(layer_infos) > -1:  # effectively `while True`; exits via break
                try:
                    curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos))
                    break
                except Exception:
                    try:
                        curr_layer = curr_layer.__getattr__(temp_name)
                        if len(layer_infos) > 0:
                            temp_name = layer_infos.pop(0)
                        elif len(layer_infos) == 0:
                            break
                    except Exception:
                        if len(layer_infos) == 0:
                            print('Error loading layer')
                        if len(temp_name) > 0:
                            temp_name += "_" + layer_infos.pop(0)
                        else:
                            temp_name = layer_infos.pop(0)

        # Remember where/what the layer was so it can be restored afterwards.
        origin_dtype = curr_layer.weight.data.dtype
        origin_device = curr_layer.weight.data.device

        curr_layer = curr_layer.to(device, dtype)
        weight_up = elems['lora_up.weight'].to(device, dtype)
        weight_down = elems['lora_down.weight'].to(device, dtype)

        # alpha stored as raw LoRA alpha; divide by rank for effective scale.
        if 'alpha' in elems.keys():
            alpha = elems['alpha'].item() / weight_up.shape[1]
        else:
            alpha = 1.0

        # 4-D weights are 1x1 convs: squeeze to matrices, multiply, unsqueeze back.
        if len(weight_up.shape) == 4:
            curr_layer.weight.data -= multiplier * alpha * torch.mm(
                weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
            ).unsqueeze(2).unsqueeze(3)
        else:
            curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
        curr_layer = curr_layer.to(origin_device, origin_dtype)

    if sequential_cpu_offload_flag:
        pipeline.enable_sequential_cpu_offload(device=offload_device)
    return pipeline