Spaces:
Running on Zero
Running on Zero
Upload 51 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +2 -0
- app.py +278 -4
- requirements.txt +61 -0
- sample/lime/first_frame.jpg +3 -0
- sample/lime/input_video.mp4 +3 -0
- sample/lime/prompt.json +3 -0
- sample/lime/quadmask_0.mp4 +3 -0
- sample/lime/segmentation_info.json +221 -0
- sample/moving_ball/first_frame.jpg +3 -0
- sample/moving_ball/input_video.mp4 +3 -0
- sample/moving_ball/prompt.json +3 -0
- sample/moving_ball/quadmask_0.mp4 +3 -0
- sample/pillow/input_video.mp4 +3 -0
- sample/pillow/prompt.json +3 -0
- sample/pillow/quadmask_0.mp4 +3 -0
- sample/pillow/segmentation_info.json +85 -0
- videox_fun/__init__.py +0 -0
- videox_fun/api/api.py +213 -0
- videox_fun/api/api_multi_nodes.py +215 -0
- videox_fun/data/bucket_sampler.py +390 -0
- videox_fun/data/dataset_image.py +76 -0
- videox_fun/data/dataset_image_video.py +1067 -0
- videox_fun/data/dataset_image_video_warped.py +1092 -0
- videox_fun/data/dataset_video.py +262 -0
- videox_fun/dist/__init__.py +40 -0
- videox_fun/dist/cogvideox_xfuser.py +116 -0
- videox_fun/dist/wan_xfuser.py +115 -0
- videox_fun/models/__init__.py +4 -0
- videox_fun/models/cache_utils.py +74 -0
- videox_fun/models/cogvideox_transformer3d.py +845 -0
- videox_fun/models/cogvideox_vae.py +1675 -0
- videox_fun/pipeline/__init__.py +2 -0
- videox_fun/pipeline/pipeline_cogvideox_fun.py +862 -0
- videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py +1244 -0
- videox_fun/pipeline/pipeline_wan_fun.py +558 -0
- videox_fun/reward/MPS/README.md +1 -0
- videox_fun/reward/MPS/trainer/models/base_model.py +7 -0
- videox_fun/reward/MPS/trainer/models/clip_model.py +154 -0
- videox_fun/reward/MPS/trainer/models/cross_modeling.py +291 -0
- videox_fun/reward/aesthetic_predictor_v2_5/__init__.py +13 -0
- videox_fun/reward/aesthetic_predictor_v2_5/siglip_v2_5.py +133 -0
- videox_fun/reward/improved_aesthetic_predictor.py +49 -0
- videox_fun/reward/reward_fn.py +385 -0
- videox_fun/ui/cogvideox_fun_ui.py +667 -0
- videox_fun/ui/ui.py +290 -0
- videox_fun/ui/wan_fun_ui.py +630 -0
- videox_fun/utils/__init__.py +0 -0
- videox_fun/utils/discrete_sampler.py +46 -0
- videox_fun/utils/fp8_optimization.py +56 -0
- videox_fun/utils/lora_utils.py +516 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
sample/lime/first_frame.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
sample/moving_ball/first_frame.jpg filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
|
@@ -1,7 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
demo.launch()
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VOID – Video Object and Interaction Deletion
|
| 3 |
+
Gradio demo for Hugging Face Spaces (ZeroGPU)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import tempfile
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import imageio
|
| 14 |
+
import mediapy as media
|
| 15 |
+
import spaces
|
| 16 |
import gradio as gr
|
| 17 |
+
from huggingface_hub import hf_hub_download
|
| 18 |
+
from safetensors.torch import load_file
|
| 19 |
+
from diffusers import DDIMScheduler
|
| 20 |
+
from PIL import Image
|
| 21 |
+
|
| 22 |
+
# ── project imports ────────────────────────────────────────────────────────────
|
| 23 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
| 24 |
+
|
| 25 |
+
from videox_fun.models import (
|
| 26 |
+
AutoencoderKLCogVideoX,
|
| 27 |
+
CogVideoXTransformer3DModel,
|
| 28 |
+
T5EncoderModel,
|
| 29 |
+
T5Tokenizer,
|
| 30 |
+
)
|
| 31 |
+
from videox_fun.pipeline import CogVideoXFunInpaintPipeline
|
| 32 |
+
from videox_fun.utils.fp8_optimization import convert_weight_dtype_wrapper
|
| 33 |
+
from videox_fun.utils.utils import temporal_padding
|
| 34 |
+
|
| 35 |
+
# ── constants ──────────────────────────────────────────────────────────────────
|
| 36 |
+
# Set these env vars in your HF Space settings, or hardcode once weights are public.
|
| 37 |
+
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "alibaba-pai/CogVideoX-Fun-V1.5-5b-InP")
|
| 38 |
+
VOID_MODEL_ID = os.environ.get("VOID_MODEL_ID", "your-hf-username/VOID")
|
| 39 |
+
VOID_CKPT_FILE = "void_pass1.safetensors"
|
| 40 |
+
|
| 41 |
+
SAMPLE_SIZE = (384, 672) # H × W
|
| 42 |
+
MAX_VID_LEN = 197
|
| 43 |
+
TEMPORAL_WIN = 85
|
| 44 |
+
FPS = 12
|
| 45 |
+
WEIGHT_DTYPE = torch.bfloat16
|
| 46 |
+
NEG_PROMPT = (
|
| 47 |
+
"The video is not of a high quality, it has a low resolution. "
|
| 48 |
+
"Watermark present in each frame. The background is solid. "
|
| 49 |
+
"Strange body and strange trajectory. Distortion."
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
# ── model loading (once at startup, lives in CPU RAM between GPU requests) ─────
|
| 53 |
+
print("Loading VOID pipeline …")
|
| 54 |
+
|
| 55 |
+
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
| 56 |
+
BASE_MODEL_ID,
|
| 57 |
+
subfolder="transformer",
|
| 58 |
+
low_cpu_mem_usage=True,
|
| 59 |
+
torch_dtype=torch.float8_e4m3fn, # qfloat8 to save VRAM
|
| 60 |
+
use_vae_mask=True,
|
| 61 |
+
stack_mask=False,
|
| 62 |
+
).to(WEIGHT_DTYPE)
|
| 63 |
+
|
| 64 |
+
# Load VOID Pass-1 checkpoint
|
| 65 |
+
ckpt_path = hf_hub_download(repo_id=VOID_MODEL_ID, filename=VOID_CKPT_FILE)
|
| 66 |
+
state_dict = load_file(ckpt_path)
|
| 67 |
+
state_dict = state_dict.get("state_dict", state_dict)
|
| 68 |
+
|
| 69 |
+
# Adapt patch_embed channels if they differ (mask-conditioning channels added)
|
| 70 |
+
param_name = "patch_embed.proj.weight"
|
| 71 |
+
if state_dict[param_name].size(1) != transformer.state_dict()[param_name].size(1):
|
| 72 |
+
feat_dim = 16 * 8 # latent_channels * feat_scale
|
| 73 |
+
new_weight = transformer.state_dict()[param_name].clone()
|
| 74 |
+
new_weight[:, :feat_dim] = state_dict[param_name][:, :feat_dim]
|
| 75 |
+
new_weight[:, -feat_dim:] = state_dict[param_name][:, -feat_dim:]
|
| 76 |
+
state_dict[param_name] = new_weight
|
| 77 |
+
|
| 78 |
+
transformer.load_state_dict(state_dict, strict=False)
|
| 79 |
+
|
| 80 |
+
vae = AutoencoderKLCogVideoX.from_pretrained(
|
| 81 |
+
BASE_MODEL_ID, subfolder="vae"
|
| 82 |
+
).to(WEIGHT_DTYPE)
|
| 83 |
+
tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL_ID, subfolder="tokenizer")
|
| 84 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
| 85 |
+
BASE_MODEL_ID, subfolder="text_encoder", torch_dtype=WEIGHT_DTYPE
|
| 86 |
+
)
|
| 87 |
+
scheduler = DDIMScheduler.from_pretrained(BASE_MODEL_ID, subfolder="scheduler")
|
| 88 |
+
|
| 89 |
+
pipeline = CogVideoXFunInpaintPipeline(
|
| 90 |
+
vae=vae,
|
| 91 |
+
tokenizer=tokenizer,
|
| 92 |
+
text_encoder=text_encoder,
|
| 93 |
+
transformer=transformer,
|
| 94 |
+
scheduler=scheduler,
|
| 95 |
+
)
|
| 96 |
+
convert_weight_dtype_wrapper(transformer, WEIGHT_DTYPE)
|
| 97 |
+
pipeline.enable_model_cpu_offload()
|
| 98 |
+
|
| 99 |
+
print("VOID pipeline ready.")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ── helpers ────────────────────────────────────────────────────────────────────
|
| 103 |
+
def load_video_tensor(path: str) -> torch.Tensor:
|
| 104 |
+
"""Return (1, C, T, H, W) float32 in [0, 1] resized to SAMPLE_SIZE."""
|
| 105 |
+
frames = media.read_video(path)
|
| 106 |
+
t = torch.from_numpy(np.array(frames))[:MAX_VID_LEN] # (T, H, W, C)
|
| 107 |
+
t = t.permute(3, 0, 1, 2).float() / 255.0 # (C, T, H, W)
|
| 108 |
+
t = F.interpolate(t, SAMPLE_SIZE, mode="area").unsqueeze(0)
|
| 109 |
+
return t
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def load_quadmask_tensor(path: str) -> torch.Tensor:
|
| 113 |
+
"""
|
| 114 |
+
Return (1, 1, T, H, W) float32 in [0, 1].
|
| 115 |
+
|
| 116 |
+
Quadmask pixel values:
|
| 117 |
+
0 → primary object (to erase)
|
| 118 |
+
63 → overlap / interaction zone
|
| 119 |
+
127 → affected region (shadows, reflections …)
|
| 120 |
+
255 → background (keep)
|
| 121 |
+
|
| 122 |
+
After quantisation the mask is inverted so 255 = "erase", 0 = "keep",
|
| 123 |
+
matching the pipeline's internal convention.
|
| 124 |
+
"""
|
| 125 |
+
frames = media.read_video(path)[:MAX_VID_LEN]
|
| 126 |
+
if frames.ndim == 4:
|
| 127 |
+
frames = frames[..., 0] # take first channel, grayscale
|
| 128 |
+
m = torch.from_numpy(np.array(frames)).unsqueeze(0).float() # (1, T, H, W)
|
| 129 |
+
m = F.interpolate(m, SAMPLE_SIZE, mode="area").unsqueeze(0) # (1, 1, T, H, W)
|
| 130 |
+
|
| 131 |
+
# Quantise to four canonical values
|
| 132 |
+
m = torch.where(m <= 31, torch.zeros_like(m), m)
|
| 133 |
+
m = torch.where((m > 31) & (m <= 95), torch.full_like(m, 63), m)
|
| 134 |
+
m = torch.where((m > 95) & (m <= 191), torch.full_like(m, 127), m)
|
| 135 |
+
m = torch.where(m > 191, torch.full_like(m, 255), m)
|
| 136 |
+
|
| 137 |
+
m = 255.0 - m # invert
|
| 138 |
+
return m / 255.0
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def tensor_to_mp4(video: torch.Tensor) -> str:
|
| 142 |
+
"""Save (1, C, T, H, W) in [0, 1] to a temp mp4 and return the path."""
|
| 143 |
+
frames = video[0].permute(1, 2, 3, 0).cpu().float().numpy() # (T, H, W, C)
|
| 144 |
+
frames = (frames * 255).clip(0, 255).astype(np.uint8)
|
| 145 |
+
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
| 146 |
+
imageio.mimsave(tmp.name, frames, fps=FPS)
|
| 147 |
+
return tmp.name
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# ── inference ──────────────────────────────────────────────────────────────────
|
| 151 |
+
@spaces.GPU(duration=300)
|
| 152 |
+
def run_inpaint(
|
| 153 |
+
input_video_path: str,
|
| 154 |
+
mask_video_path: str,
|
| 155 |
+
prompt: str,
|
| 156 |
+
num_steps: int,
|
| 157 |
+
guidance_scale: float,
|
| 158 |
+
seed: int,
|
| 159 |
+
) -> str:
|
| 160 |
+
if not input_video_path or not mask_video_path:
|
| 161 |
+
raise gr.Error("Please upload both an input video and a quadmask video.")
|
| 162 |
+
if not prompt.strip():
|
| 163 |
+
raise gr.Error("Please enter a prompt describing the scene after removal.")
|
| 164 |
+
|
| 165 |
+
generator = torch.Generator(device="cuda").manual_seed(int(seed))
|
| 166 |
+
|
| 167 |
+
input_video = load_video_tensor(input_video_path)
|
| 168 |
+
input_mask = load_quadmask_tensor(mask_video_path)
|
| 169 |
+
|
| 170 |
+
input_video = temporal_padding(input_video, min_length=TEMPORAL_WIN, max_length=MAX_VID_LEN)
|
| 171 |
+
input_mask = temporal_padding(input_mask, min_length=TEMPORAL_WIN, max_length=MAX_VID_LEN)
|
| 172 |
+
|
| 173 |
+
with torch.no_grad():
|
| 174 |
+
result = pipeline(
|
| 175 |
+
prompt=prompt,
|
| 176 |
+
negative_prompt=NEG_PROMPT,
|
| 177 |
+
height=SAMPLE_SIZE[0],
|
| 178 |
+
width=SAMPLE_SIZE[1],
|
| 179 |
+
num_frames=TEMPORAL_WIN,
|
| 180 |
+
video=input_video,
|
| 181 |
+
mask_video=input_mask,
|
| 182 |
+
generator=generator,
|
| 183 |
+
guidance_scale=guidance_scale,
|
| 184 |
+
num_inference_steps=num_steps,
|
| 185 |
+
strength=1.0,
|
| 186 |
+
use_trimask=True,
|
| 187 |
+
use_vae_mask=True,
|
| 188 |
+
stack_mask=False,
|
| 189 |
+
zero_out_mask_region=False,
|
| 190 |
+
).videos
|
| 191 |
+
|
| 192 |
+
return tensor_to_mp4(result)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# ── Gradio UI ──────────────────────────────────────────────────────────────────
|
| 196 |
+
QUADMASK_EXPLAINER = """
|
| 197 |
+
### Quadmask format
|
| 198 |
+
|
| 199 |
+
The quadmask is a **grayscale video** where each pixel value encodes what role that region plays:
|
| 200 |
+
|
| 201 |
+
| Pixel value | Meaning |
|
| 202 |
+
|-------------|---------|
|
| 203 |
+
| **0** (black) | Primary object to remove |
|
| 204 |
+
| **63** (dark grey) | Overlap / interaction zone |
|
| 205 |
+
| **127** (mid grey) | Affected region — shadows, reflections, secondary effects |
|
| 206 |
+
| **255** (white) | Background — keep as-is |
|
| 207 |
+
|
| 208 |
+
Use the **VLM-Mask-Reasoner** pipeline included in the repo to generate quadmasks automatically.
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
SAMPLE_DIR = os.path.join(os.path.dirname(__file__), "sample")
|
| 212 |
+
EXAMPLES = [
|
| 213 |
+
[
|
| 214 |
+
os.path.join(SAMPLE_DIR, "lime", "input_video.mp4"),
|
| 215 |
+
os.path.join(SAMPLE_DIR, "lime", "quadmask_0.mp4"),
|
| 216 |
+
"A lime falls on the table.",
|
| 217 |
+
30, 1.0, 42,
|
| 218 |
+
],
|
| 219 |
+
[
|
| 220 |
+
os.path.join(SAMPLE_DIR, "moving_ball", "input_video.mp4"),
|
| 221 |
+
os.path.join(SAMPLE_DIR, "moving_ball", "quadmask_0.mp4"),
|
| 222 |
+
"A ball rolls off the table.",
|
| 223 |
+
30, 1.0, 42,
|
| 224 |
+
],
|
| 225 |
+
[
|
| 226 |
+
os.path.join(SAMPLE_DIR, "pillow", "input_video.mp4"),
|
| 227 |
+
os.path.join(SAMPLE_DIR, "pillow", "quadmask_0.mp4"),
|
| 228 |
+
"Two pillows placed on the table.",
|
| 229 |
+
30, 1.0, 42,
|
| 230 |
+
],
|
| 231 |
+
]
|
| 232 |
+
|
| 233 |
+
with gr.Blocks(title="VOID – Video Object & Interaction Deletion") as demo:
|
| 234 |
+
gr.Markdown(
|
| 235 |
+
"""
|
| 236 |
+
# VOID – Video Object and Interaction Deletion
|
| 237 |
+
|
| 238 |
+
Upload a video and its **quadmask**, enter a prompt describing the scene *after* removal,
|
| 239 |
+
and VOID will erase the object along with its physical interactions (shadows, deformations, secondary motion).
|
| 240 |
+
|
| 241 |
+
> Built on **CogVideoX-Fun-V1.5-5B** fine-tuned for interaction-aware video inpainting.
|
| 242 |
+
"""
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
with gr.Row():
|
| 246 |
+
with gr.Column():
|
| 247 |
+
input_video = gr.Video(label="Input video", sources=["upload"])
|
| 248 |
+
mask_video = gr.Video(label="Quadmask video", sources=["upload"])
|
| 249 |
+
prompt = gr.Textbox(
|
| 250 |
+
label="Prompt — describe the scene after removal",
|
| 251 |
+
placeholder="e.g. A wooden table with nothing on it.",
|
| 252 |
+
lines=2,
|
| 253 |
+
)
|
| 254 |
+
with gr.Accordion("Advanced settings", open=False):
|
| 255 |
+
num_steps = gr.Slider(10, 50, value=30, step=1, label="Inference steps")
|
| 256 |
+
guidance_scale = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="Guidance scale")
|
| 257 |
+
seed = gr.Number(value=42, label="Seed", precision=0)
|
| 258 |
+
run_btn = gr.Button("Run VOID", variant="primary")
|
| 259 |
+
|
| 260 |
+
with gr.Column():
|
| 261 |
+
output_video = gr.Video(label="Inpainted output", interactive=False)
|
| 262 |
+
|
| 263 |
+
gr.Markdown(QUADMASK_EXPLAINER)
|
| 264 |
+
|
| 265 |
+
gr.Examples(
|
| 266 |
+
examples=EXAMPLES,
|
| 267 |
+
inputs=[input_video, mask_video, prompt, num_steps, guidance_scale, seed],
|
| 268 |
+
outputs=[output_video],
|
| 269 |
+
fn=run_inpaint,
|
| 270 |
+
cache_examples=True,
|
| 271 |
+
label="Sample sequences — click to load and run",
|
| 272 |
+
)
|
| 273 |
|
| 274 |
+
run_btn.click(
|
| 275 |
+
fn=run_inpaint,
|
| 276 |
+
inputs=[input_video, mask_video, prompt, num_steps, guidance_scale, seed],
|
| 277 |
+
outputs=[output_video],
|
| 278 |
+
)
|
| 279 |
|
| 280 |
+
if __name__ == "__main__":
|
| 281 |
+
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core deep learning
|
| 2 |
+
torch==2.7.1
|
| 3 |
+
torchvision==0.22.1
|
| 4 |
+
torchdiffeq==0.2.5
|
| 5 |
+
torchsde==0.2.6
|
| 6 |
+
|
| 7 |
+
# Diffusion / generation
|
| 8 |
+
diffusers==0.33.1
|
| 9 |
+
accelerate==1.12.0
|
| 10 |
+
transformers==4.57.1
|
| 11 |
+
safetensors==0.6.2
|
| 12 |
+
peft==0.17.1
|
| 13 |
+
|
| 14 |
+
# Training utilities
|
| 15 |
+
deepspeed==0.17.6
|
| 16 |
+
came-pytorch==0.1.3
|
| 17 |
+
tensorboard==2.20.0
|
| 18 |
+
|
| 19 |
+
# Vision / video
|
| 20 |
+
opencv-python==4.10.0.84
|
| 21 |
+
scikit-image==0.25.2
|
| 22 |
+
imageio==2.37.0
|
| 23 |
+
imageio-ffmpeg==0.6.0
|
| 24 |
+
mediapy==1.2.4
|
| 25 |
+
decord==0.6.0
|
| 26 |
+
kornia==0.8.1
|
| 27 |
+
albumentations==2.0.8
|
| 28 |
+
timm==1.0.19
|
| 29 |
+
tomesd==0.1.3
|
| 30 |
+
Pillow==11.3.0
|
| 31 |
+
|
| 32 |
+
# Data / ML utilities
|
| 33 |
+
numpy==1.26.4
|
| 34 |
+
scipy==1.14.0
|
| 35 |
+
scikit-learn==1.7.2
|
| 36 |
+
datasets==4.0.0
|
| 37 |
+
einops==0.8.0
|
| 38 |
+
|
| 39 |
+
# Config / logging
|
| 40 |
+
omegaconf==2.3.0
|
| 41 |
+
ml_collections==1.1.0
|
| 42 |
+
absl-py==2.3.1
|
| 43 |
+
loguru==0.7.3
|
| 44 |
+
tqdm==4.67.1
|
| 45 |
+
matplotlib==3.10.6
|
| 46 |
+
|
| 47 |
+
# NLP
|
| 48 |
+
sentencepiece==0.2.1
|
| 49 |
+
ftfy==6.1.1
|
| 50 |
+
beautifulsoup4==4.13.5
|
| 51 |
+
|
| 52 |
+
# Misc
|
| 53 |
+
func-timeout==4.3.5
|
| 54 |
+
requests==2.32.5
|
| 55 |
+
packaging==25.0
|
| 56 |
+
|
| 57 |
+
# Optional: Gradio UI (only needed for app.py / demo)
|
| 58 |
+
# gradio>=3.41.2,<=3.48.0
|
| 59 |
+
|
| 60 |
+
# Note: SAM2 must be installed separately per the instructions at
|
| 61 |
+
# https://github.com/facebookresearch/sam2?tab=readme-ov-file#installation
|
sample/lime/first_frame.jpg
ADDED
|
Git LFS Details
|
sample/lime/input_video.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0efabfbfc85bf29d11ac0f734eccf5dc824c511333c15953b73d3e357d7d9a87
|
| 3 |
+
size 3892459
|
sample/lime/prompt.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bg": "A lime falls on the table."
|
| 3 |
+
}
|
sample/lime/quadmask_0.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:00a01b7fb47107edcbfd5a036d6d7b1097ea8624df9c2440d184ddfa90a8bdd5
|
| 3 |
+
size 1907329
|
sample/lime/segmentation_info.json
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"total_frames": 46,
|
| 3 |
+
"frame_width": 3840,
|
| 4 |
+
"frame_height": 2160,
|
| 5 |
+
"fps": 12.0,
|
| 6 |
+
"num_points": 25,
|
| 7 |
+
"points_by_frame": {
|
| 8 |
+
"0": [
|
| 9 |
+
[
|
| 10 |
+
2126,
|
| 11 |
+
1099
|
| 12 |
+
],
|
| 13 |
+
[
|
| 14 |
+
2366,
|
| 15 |
+
1099
|
| 16 |
+
],
|
| 17 |
+
[
|
| 18 |
+
2683,
|
| 19 |
+
1080
|
| 20 |
+
],
|
| 21 |
+
[
|
| 22 |
+
2784,
|
| 23 |
+
1176
|
| 24 |
+
],
|
| 25 |
+
[
|
| 26 |
+
2640,
|
| 27 |
+
1176
|
| 28 |
+
],
|
| 29 |
+
[
|
| 30 |
+
2539,
|
| 31 |
+
1176
|
| 32 |
+
],
|
| 33 |
+
[
|
| 34 |
+
2318,
|
| 35 |
+
1176
|
| 36 |
+
],
|
| 37 |
+
[
|
| 38 |
+
2116,
|
| 39 |
+
1291
|
| 40 |
+
],
|
| 41 |
+
[
|
| 42 |
+
2496,
|
| 43 |
+
1291
|
| 44 |
+
],
|
| 45 |
+
[
|
| 46 |
+
2654,
|
| 47 |
+
1286
|
| 48 |
+
],
|
| 49 |
+
[
|
| 50 |
+
2654,
|
| 51 |
+
1406
|
| 52 |
+
],
|
| 53 |
+
[
|
| 54 |
+
2342,
|
| 55 |
+
1406
|
| 56 |
+
],
|
| 57 |
+
[
|
| 58 |
+
2342,
|
| 59 |
+
1776
|
| 60 |
+
],
|
| 61 |
+
[
|
| 62 |
+
2620,
|
| 63 |
+
1776
|
| 64 |
+
],
|
| 65 |
+
[
|
| 66 |
+
2539,
|
| 67 |
+
1924
|
| 68 |
+
],
|
| 69 |
+
[
|
| 70 |
+
2304,
|
| 71 |
+
1972
|
| 72 |
+
],
|
| 73 |
+
[
|
| 74 |
+
2217,
|
| 75 |
+
1992
|
| 76 |
+
],
|
| 77 |
+
[
|
| 78 |
+
2385,
|
| 79 |
+
2030
|
| 80 |
+
],
|
| 81 |
+
[
|
| 82 |
+
2596,
|
| 83 |
+
2025
|
| 84 |
+
],
|
| 85 |
+
[
|
| 86 |
+
2673,
|
| 87 |
+
1987
|
| 88 |
+
],
|
| 89 |
+
[
|
| 90 |
+
2217,
|
| 91 |
+
1776
|
| 92 |
+
],
|
| 93 |
+
[
|
| 94 |
+
2198,
|
| 95 |
+
1660
|
| 96 |
+
],
|
| 97 |
+
[
|
| 98 |
+
2452,
|
| 99 |
+
1588
|
| 100 |
+
],
|
| 101 |
+
[
|
| 102 |
+
2294,
|
| 103 |
+
1483
|
| 104 |
+
],
|
| 105 |
+
[
|
| 106 |
+
2270,
|
| 107 |
+
1358
|
| 108 |
+
]
|
| 109 |
+
]
|
| 110 |
+
},
|
| 111 |
+
"video_path": "limecoke.mp4",
|
| 112 |
+
"instruction": "",
|
| 113 |
+
"primary_points_by_frame": {
|
| 114 |
+
"0": [
|
| 115 |
+
[
|
| 116 |
+
2126,
|
| 117 |
+
1099
|
| 118 |
+
],
|
| 119 |
+
[
|
| 120 |
+
2366,
|
| 121 |
+
1099
|
| 122 |
+
],
|
| 123 |
+
[
|
| 124 |
+
2683,
|
| 125 |
+
1080
|
| 126 |
+
],
|
| 127 |
+
[
|
| 128 |
+
2784,
|
| 129 |
+
1176
|
| 130 |
+
],
|
| 131 |
+
[
|
| 132 |
+
2640,
|
| 133 |
+
1176
|
| 134 |
+
],
|
| 135 |
+
[
|
| 136 |
+
2539,
|
| 137 |
+
1176
|
| 138 |
+
],
|
| 139 |
+
[
|
| 140 |
+
2318,
|
| 141 |
+
1176
|
| 142 |
+
],
|
| 143 |
+
[
|
| 144 |
+
2116,
|
| 145 |
+
1291
|
| 146 |
+
],
|
| 147 |
+
[
|
| 148 |
+
2496,
|
| 149 |
+
1291
|
| 150 |
+
],
|
| 151 |
+
[
|
| 152 |
+
2654,
|
| 153 |
+
1286
|
| 154 |
+
],
|
| 155 |
+
[
|
| 156 |
+
2654,
|
| 157 |
+
1406
|
| 158 |
+
],
|
| 159 |
+
[
|
| 160 |
+
2342,
|
| 161 |
+
1406
|
| 162 |
+
],
|
| 163 |
+
[
|
| 164 |
+
2342,
|
| 165 |
+
1776
|
| 166 |
+
],
|
| 167 |
+
[
|
| 168 |
+
2620,
|
| 169 |
+
1776
|
| 170 |
+
],
|
| 171 |
+
[
|
| 172 |
+
2539,
|
| 173 |
+
1924
|
| 174 |
+
],
|
| 175 |
+
[
|
| 176 |
+
2304,
|
| 177 |
+
1972
|
| 178 |
+
],
|
| 179 |
+
[
|
| 180 |
+
2217,
|
| 181 |
+
1992
|
| 182 |
+
],
|
| 183 |
+
[
|
| 184 |
+
2385,
|
| 185 |
+
2030
|
| 186 |
+
],
|
| 187 |
+
[
|
| 188 |
+
2596,
|
| 189 |
+
2025
|
| 190 |
+
],
|
| 191 |
+
[
|
| 192 |
+
2673,
|
| 193 |
+
1987
|
| 194 |
+
],
|
| 195 |
+
[
|
| 196 |
+
2217,
|
| 197 |
+
1776
|
| 198 |
+
],
|
| 199 |
+
[
|
| 200 |
+
2198,
|
| 201 |
+
1660
|
| 202 |
+
],
|
| 203 |
+
[
|
| 204 |
+
2452,
|
| 205 |
+
1588
|
| 206 |
+
],
|
| 207 |
+
[
|
| 208 |
+
2294,
|
| 209 |
+
1483
|
| 210 |
+
],
|
| 211 |
+
[
|
| 212 |
+
2270,
|
| 213 |
+
1358
|
| 214 |
+
]
|
| 215 |
+
]
|
| 216 |
+
},
|
| 217 |
+
"primary_frames": [
|
| 218 |
+
0
|
| 219 |
+
],
|
| 220 |
+
"first_appears_frame": 0
|
| 221 |
+
}
|
sample/moving_ball/first_frame.jpg
ADDED
|
Git LFS Details
|
sample/moving_ball/input_video.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e07906cc204ba26c0dd05eed545030cb7e79f2742e983ff0b04d2d9c3c762d29
|
| 3 |
+
size 2014662
|
sample/moving_ball/prompt.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bg": "A ball rolls off the table."
|
| 3 |
+
}
|
sample/moving_ball/quadmask_0.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5904642de05a65f210bd49e3c24b7d0657ef57ff40eb9baafd562962c9dd9189
|
| 3 |
+
size 2485881
|
sample/pillow/input_video.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6ca3e6b666497e053491772e8f0317e22520c63ebaa8896b8378757d016e0f75
|
| 3 |
+
size 2960087
|
sample/pillow/prompt.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bg": "Two pillows placed on the table."
|
| 3 |
+
}
|
sample/pillow/quadmask_0.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7eb70257593da06f682a3ddda54a9d260d4fc514f645237f5ca74b08f8da61a6
|
| 3 |
+
size 2
|
sample/pillow/segmentation_info.json
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"total_frames": 62,
|
| 3 |
+
"frame_width": 3840,
|
| 4 |
+
"frame_height": 2160,
|
| 5 |
+
"fps": 12.0,
|
| 6 |
+
"num_points": 8,
|
| 7 |
+
"points_by_frame": {
|
| 8 |
+
"0": [
|
| 9 |
+
[
|
| 10 |
+
1507,
|
| 11 |
+
724
|
| 12 |
+
],
|
| 13 |
+
[
|
| 14 |
+
1363,
|
| 15 |
+
638
|
| 16 |
+
],
|
| 17 |
+
[
|
| 18 |
+
1190,
|
| 19 |
+
475
|
| 20 |
+
],
|
| 21 |
+
[
|
| 22 |
+
1276,
|
| 23 |
+
187
|
| 24 |
+
],
|
| 25 |
+
[
|
| 26 |
+
1545,
|
| 27 |
+
168
|
| 28 |
+
],
|
| 29 |
+
[
|
| 30 |
+
1660,
|
| 31 |
+
259
|
| 32 |
+
],
|
| 33 |
+
[
|
| 34 |
+
1684,
|
| 35 |
+
393
|
| 36 |
+
],
|
| 37 |
+
[
|
| 38 |
+
1579,
|
| 39 |
+
825
|
| 40 |
+
]
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
"video_path": "teaser3/weight_on_pillow.mp4",
|
| 44 |
+
"instruction": "segment the weight",
|
| 45 |
+
"primary_points_by_frame": {
|
| 46 |
+
"0": [
|
| 47 |
+
[
|
| 48 |
+
1507,
|
| 49 |
+
724
|
| 50 |
+
],
|
| 51 |
+
[
|
| 52 |
+
1363,
|
| 53 |
+
638
|
| 54 |
+
],
|
| 55 |
+
[
|
| 56 |
+
1190,
|
| 57 |
+
475
|
| 58 |
+
],
|
| 59 |
+
[
|
| 60 |
+
1276,
|
| 61 |
+
187
|
| 62 |
+
],
|
| 63 |
+
[
|
| 64 |
+
1545,
|
| 65 |
+
168
|
| 66 |
+
],
|
| 67 |
+
[
|
| 68 |
+
1660,
|
| 69 |
+
259
|
| 70 |
+
],
|
| 71 |
+
[
|
| 72 |
+
1684,
|
| 73 |
+
393
|
| 74 |
+
],
|
| 75 |
+
[
|
| 76 |
+
1579,
|
| 77 |
+
825
|
| 78 |
+
]
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
"primary_frames": [
|
| 82 |
+
0
|
| 83 |
+
],
|
| 84 |
+
"first_appears_frame": 0
|
| 85 |
+
}
|
videox_fun/__init__.py
ADDED
|
File without changes
|
videox_fun/api/api.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import gc
|
| 3 |
+
import hashlib
|
| 4 |
+
import io
|
| 5 |
+
import os
|
| 6 |
+
import tempfile
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import requests
|
| 11 |
+
import torch
|
| 12 |
+
from fastapi import FastAPI
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Function to encode a file to Base64
|
| 17 |
+
def encode_file_to_base64(file_path):
|
| 18 |
+
with open(file_path, "rb") as file:
|
| 19 |
+
# Encode the data to Base64
|
| 20 |
+
file_base64 = base64.b64encode(file.read())
|
| 21 |
+
return file_base64
|
| 22 |
+
|
| 23 |
+
def update_edition_api(_: gr.Blocks, app: FastAPI, controller):
|
| 24 |
+
@app.post("/videox_fun/update_edition")
|
| 25 |
+
def _update_edition_api(
|
| 26 |
+
datas: dict,
|
| 27 |
+
):
|
| 28 |
+
edition = datas.get('edition', 'v2')
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
controller.update_edition(
|
| 32 |
+
edition
|
| 33 |
+
)
|
| 34 |
+
comment = "Success"
|
| 35 |
+
except Exception as e:
|
| 36 |
+
torch.cuda.empty_cache()
|
| 37 |
+
comment = f"Error. error information is {str(e)}"
|
| 38 |
+
|
| 39 |
+
return {"message": comment}
|
| 40 |
+
|
| 41 |
+
def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
|
| 42 |
+
@app.post("/videox_fun/update_diffusion_transformer")
|
| 43 |
+
def _update_diffusion_transformer_api(
|
| 44 |
+
datas: dict,
|
| 45 |
+
):
|
| 46 |
+
diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
controller.update_diffusion_transformer(
|
| 50 |
+
diffusion_transformer_path
|
| 51 |
+
)
|
| 52 |
+
comment = "Success"
|
| 53 |
+
except Exception as e:
|
| 54 |
+
torch.cuda.empty_cache()
|
| 55 |
+
comment = f"Error. error information is {str(e)}"
|
| 56 |
+
|
| 57 |
+
return {"message": comment}
|
| 58 |
+
|
| 59 |
+
def download_from_url(url, timeout=10):
|
| 60 |
+
try:
|
| 61 |
+
response = requests.get(url, timeout=timeout)
|
| 62 |
+
response.raise_for_status() # 检查请求是否成功
|
| 63 |
+
return response.content
|
| 64 |
+
except requests.exceptions.RequestException as e:
|
| 65 |
+
print(f"Error downloading from {url}: {e}")
|
| 66 |
+
return None
|
| 67 |
+
|
| 68 |
+
def save_base64_video(base64_string):
|
| 69 |
+
video_data = base64.b64decode(base64_string)
|
| 70 |
+
|
| 71 |
+
md5_hash = hashlib.md5(video_data).hexdigest()
|
| 72 |
+
filename = f"{md5_hash}.mp4"
|
| 73 |
+
|
| 74 |
+
temp_dir = tempfile.gettempdir()
|
| 75 |
+
file_path = os.path.join(temp_dir, filename)
|
| 76 |
+
|
| 77 |
+
with open(file_path, 'wb') as video_file:
|
| 78 |
+
video_file.write(video_data)
|
| 79 |
+
|
| 80 |
+
return file_path
|
| 81 |
+
|
| 82 |
+
def save_base64_image(base64_string):
    """Decode a base64-encoded image and persist it under the system temp dir.

    Mirrors ``save_base64_video`` but writes a ``.jpg`` file named after the
    MD5 digest of the decoded bytes. Returns the path of the written file.
    """
    image_bytes = base64.b64decode(base64_string)

    digest = hashlib.md5(image_bytes).hexdigest()
    file_path = os.path.join(tempfile.gettempdir(), f"{digest}.jpg")

    with open(file_path, 'wb') as handle:
        handle.write(image_bytes)

    return file_path
|
| 95 |
+
|
| 96 |
+
def save_url_video(url):
    """Download a video from *url* into the temp dir; return its path or None."""
    payload = download_from_url(url)
    if not payload:
        return None
    # Reuse the base64 path so both input channels share one storage routine.
    return save_base64_video(base64.b64encode(payload))
|
| 101 |
+
|
| 102 |
+
def save_url_image(url):
    """Download an image from *url* into the temp dir; return its path or None."""
    payload = download_from_url(url)
    if not payload:
        return None
    # Reuse the base64 path so both input channels share one storage routine.
    return save_base64_image(base64.b64encode(payload))
|
| 107 |
+
|
| 108 |
+
def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
    # Registers the main single-node generation endpoint on the FastAPI app.
    # The gr.Blocks argument is unused but kept for a uniform registration
    # signature across the API helpers.
    @app.post("/videox_fun/infer_forward")
    def _infer_forward_api(
        datas: dict,
    ):
        # --- Unpack the JSON payload, falling back to the Gradio UI defaults. ---
        base_model_path = datas.get('base_model_path', 'none')
        lora_model_path = datas.get('lora_model_path', 'none')
        lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
        prompt_textbox = datas.get('prompt_textbox', None)
        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
        sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
        sample_step_slider = datas.get('sample_step_slider', 30)
        resize_method = datas.get('resize_method', "Generate by")
        width_slider = datas.get('width_slider', 672)
        height_slider = datas.get('height_slider', 384)
        base_resolution = datas.get('base_resolution', 512)
        is_image = datas.get('is_image', False)
        # NOTE(review): default is False rather than a method name string —
        # presumably controller.generate treats a falsy value as "video";
        # confirm against the controller implementation.
        generation_method = datas.get('generation_method', False)
        length_slider = datas.get('length_slider', 49)
        overlap_video_length = datas.get('overlap_video_length', 4)
        partial_video_length = datas.get('partial_video_length', 72)
        cfg_scale_slider = datas.get('cfg_scale_slider', 6)
        start_image = datas.get('start_image', None)
        end_image = datas.get('end_image', None)
        validation_video = datas.get('validation_video', None)
        validation_video_mask = datas.get('validation_video_mask', None)
        control_video = datas.get('control_video', None)
        denoise_strength = datas.get('denoise_strength', 0.70)
        seed_textbox = datas.get("seed_textbox", 43)

        # A single-image request overrides whatever generation method was sent.
        generation_method = "Image Generation" if is_image else generation_method

        # Media inputs may arrive either as URLs (fetched to the temp dir)
        # or as base64 payloads (decoded in place).
        if start_image is not None:
            if start_image.startswith('http'):
                start_image = save_url_image(start_image)
                start_image = [Image.open(start_image)]
            else:
                start_image = base64.b64decode(start_image)
                start_image = [Image.open(BytesIO(start_image))]

        if end_image is not None:
            if end_image.startswith('http'):
                end_image = save_url_image(end_image)
                end_image = [Image.open(end_image)]
            else:
                end_image = base64.b64decode(end_image)
                end_image = [Image.open(BytesIO(end_image))]

        # Videos are always materialized to temp files (the pipeline reads paths).
        if validation_video is not None:
            if validation_video.startswith('http'):
                validation_video = save_url_video(validation_video)
            else:
                validation_video = save_base64_video(validation_video)

        if validation_video_mask is not None:
            if validation_video_mask.startswith('http'):
                validation_video_mask = save_url_image(validation_video_mask)
            else:
                validation_video_mask = save_base64_image(validation_video_mask)

        if control_video is not None:
            if control_video.startswith('http'):
                control_video = save_url_video(control_video)
            else:
                control_video = save_base64_video(control_video)

        try:
            # Positional argument order must match controller.generate exactly;
            # the leading "" is an unused slot in that signature.
            save_sample_path, comment = controller.generate(
                "",
                base_model_path,
                lora_model_path,
                lora_alpha_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                resize_method,
                width_slider,
                height_slider,
                base_resolution,
                generation_method,
                length_slider,
                overlap_video_length,
                partial_video_length,
                cfg_scale_slider,
                start_image,
                end_image,
                validation_video,
                validation_video_mask,
                control_video,
                denoise_strength,
                seed_textbox,
                is_api = True,
            )
        except Exception as e:
            # Best-effort GPU/heap cleanup after a failed generation, then
            # report the error as a normal JSON message (no HTTP error status).
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            save_sample_path = ""
            comment = f"Error. error information is {str(e)}"
            return {"message": comment}

        # On success also inline the produced file as base64 for remote callers.
        if save_sample_path != "":
            return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
        else:
            return {"message": comment, "save_sample_path": save_sample_path}
|
videox_fun/api/api_multi_nodes.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file is modified from https://github.com/xdit-project/xDiT/blob/main/entrypoints/launch.py
|
| 2 |
+
import base64
|
| 3 |
+
import gc
|
| 4 |
+
import os
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import torch
|
| 9 |
+
from fastapi import FastAPI, HTTPException
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
from .api import (encode_file_to_base64, save_base64_image, save_base64_video,
|
| 13 |
+
save_url_image, save_url_video)
|
| 14 |
+
|
| 15 |
+
# Ray is an optional dependency, only needed for the multi-GPU/multi-node API.
# Fall back to `ray = None` so the rest of the module can feature-detect it.
try:
    import ray
except Exception:
    # FIX: was a bare `except:`, which also swallows SystemExit and
    # KeyboardInterrupt; `except Exception` keeps the best-effort fallback
    # without masking interpreter-exit signals.
    print("Ray is not installed. If you want to use multi gpus api. Please install it by running 'pip install ray'.")
    ray = None
|
| 20 |
+
|
| 21 |
+
if ray is not None:
|
| 22 |
+
@ray.remote(num_gpus=1)
class MultiNodesGenerator:
    """Ray actor pinned to one GPU that owns a full generation controller.

    Each actor exports the torch.distributed rank/world-size environment
    variables for its slot and instantiates ``Controller``; ``generate``
    mirrors the single-node ``/videox_fun/infer_forward`` endpoint body.
    """

    def __init__(
        self, rank: int, world_size: int, Controller,
        GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
        config_path=None, ulysses_degree=1, ring_degree=1,
        enable_teacache=None, teacache_threshold=None,
        num_skip_start_steps=None, teacache_offload=None, weight_dtype=None,
        savedir_sample=None,
    ):
        # Set PyTorch distributed environment variables so the controller can
        # initialize its process group inside the actor process.
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"

        self.rank = rank
        self.controller = Controller(
            GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
            ulysses_degree=ulysses_degree, ring_degree=ring_degree, enable_teacache=enable_teacache, teacache_threshold=teacache_threshold, num_skip_start_steps=num_skip_start_steps,
            teacache_offload=teacache_offload, weight_dtype=weight_dtype, savedir_sample=savedir_sample,
        )

    def generate(self, datas):
        """Run one generation request; only rank 0 returns a response dict.

        Non-zero ranks participate in the distributed generation but return
        ``None`` so the engine can pick out the single real payload.
        """
        try:
            # --- Unpack the JSON payload, falling back to the UI defaults. ---
            base_model_path = datas.get('base_model_path', 'none')
            lora_model_path = datas.get('lora_model_path', 'none')
            lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
            prompt_textbox = datas.get('prompt_textbox', None)
            negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
            sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
            sample_step_slider = datas.get('sample_step_slider', 30)
            resize_method = datas.get('resize_method', "Generate by")
            width_slider = datas.get('width_slider', 672)
            height_slider = datas.get('height_slider', 384)
            base_resolution = datas.get('base_resolution', 512)
            is_image = datas.get('is_image', False)
            generation_method = datas.get('generation_method', False)
            length_slider = datas.get('length_slider', 49)
            overlap_video_length = datas.get('overlap_video_length', 4)
            partial_video_length = datas.get('partial_video_length', 72)
            cfg_scale_slider = datas.get('cfg_scale_slider', 6)
            start_image = datas.get('start_image', None)
            end_image = datas.get('end_image', None)
            validation_video = datas.get('validation_video', None)
            validation_video_mask = datas.get('validation_video_mask', None)
            control_video = datas.get('control_video', None)
            denoise_strength = datas.get('denoise_strength', 0.70)
            seed_textbox = datas.get("seed_textbox", 43)

            # A single-image request overrides whatever method was sent.
            generation_method = "Image Generation" if is_image else generation_method

            # Media inputs arrive either as URLs or as base64 payloads.
            if start_image is not None:
                if start_image.startswith('http'):
                    start_image = save_url_image(start_image)
                    start_image = [Image.open(start_image)]
                else:
                    start_image = base64.b64decode(start_image)
                    start_image = [Image.open(BytesIO(start_image))]

            if end_image is not None:
                if end_image.startswith('http'):
                    end_image = save_url_image(end_image)
                    end_image = [Image.open(end_image)]
                else:
                    end_image = base64.b64decode(end_image)
                    end_image = [Image.open(BytesIO(end_image))]

            if validation_video is not None:
                if validation_video.startswith('http'):
                    validation_video = save_url_video(validation_video)
                else:
                    validation_video = save_base64_video(validation_video)

            if validation_video_mask is not None:
                if validation_video_mask.startswith('http'):
                    validation_video_mask = save_url_image(validation_video_mask)
                else:
                    validation_video_mask = save_base64_image(validation_video_mask)

            if control_video is not None:
                if control_video.startswith('http'):
                    control_video = save_url_video(control_video)
                else:
                    control_video = save_base64_video(control_video)

            try:
                # Positional order must match controller.generate exactly;
                # the leading "" is an unused slot in that signature.
                save_sample_path, comment = self.controller.generate(
                    "",
                    base_model_path,
                    lora_model_path,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    resize_method,
                    width_slider,
                    height_slider,
                    base_resolution,
                    generation_method,
                    length_slider,
                    overlap_video_length,
                    partial_video_length,
                    cfg_scale_slider,
                    start_image,
                    end_image,
                    validation_video,
                    validation_video_mask,
                    control_video,
                    denoise_strength,
                    seed_textbox,
                    is_api = True,
                )
            except Exception as e:
                # Best-effort GPU/heap cleanup after a failed generation,
                # reported as a normal JSON message.
                gc.collect()
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
                save_sample_path = ""
                comment = f"Error. error information is {str(e)}"
                return {"message": comment}

            import torch.distributed as dist
            if dist.get_rank() == 0:
                if save_sample_path != "":
                    return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
                else:
                    return {"message": comment, "save_sample_path": save_sample_path}
            return None

        except Exception as e:
            # FIX: the original called self.logger.error(...), but no logger
            # was ever created, so any error reaching this handler raised
            # AttributeError and hid the real cause. Report via print (this
            # module's convention) and surface the original error to FastAPI.
            print(f"Error generating image: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))
|
| 155 |
+
|
| 156 |
+
class MultiNodesEngine:
    """Fan-out wrapper that drives one ``MultiNodesGenerator`` actor per rank."""

    def __init__(
        self,
        world_size,
        Controller,
        GPU_memory_mode,
        scheduler_dict,
        model_name,
        model_type,
        config_path,
        ulysses_degree,
        ring_degree,
        enable_teacache,
        teacache_threshold,
        num_skip_start_steps,
        teacache_offload,
        weight_dtype,
        savedir_sample
    ):
        # Ray must be up before any remote actor can be created.
        if not ray.is_initialized():
            ray.init()

        # One actor per rank; each pins a GPU via @ray.remote(num_gpus=1).
        self.workers = []
        for rank in range(world_size):
            worker = MultiNodesGenerator.remote(
                rank, world_size, Controller,
                GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
                ulysses_degree=ulysses_degree, ring_degree=ring_degree, enable_teacache=enable_teacache, teacache_threshold=teacache_threshold, num_skip_start_steps=num_skip_start_steps,
                teacache_offload=teacache_offload, weight_dtype=weight_dtype, savedir_sample=savedir_sample,
            )
            self.workers.append(worker)
        print("Update workers done")

    async def generate(self, data):
        """Broadcast one request to all ranks and return rank 0's payload."""
        pending = [worker.generate.remote(data) for worker in self.workers]
        results = ray.get(pending)

        # All ranks except rank 0 return None; hand back the single payload.
        return next(path for path in results if path is not None)
|
| 198 |
+
|
| 199 |
+
def multi_nodes_infer_forward_api(_: gr.Blocks, app: FastAPI, engine):
    """Register the multi-node generation endpoint backed by *engine*."""

    @app.post("/videox_fun/infer_forward")
    async def _multi_nodes_infer_forward_api(
        datas: dict,
    ):
        try:
            return await engine.generate(datas)
        except HTTPException:
            # Already an HTTP error — propagate unchanged.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
|
| 212 |
+
else:
    # Ray is unavailable: expose None placeholders so `from ... import X`
    # still succeeds and callers can feature-detect multi-node support.
    MultiNodesEngine = None
    MultiNodesGenerator = None
    multi_nodes_infer_forward_api = None
|
videox_fun/data/bucket_sampler.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import os
|
| 3 |
+
import glob
|
| 4 |
+
from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
|
| 5 |
+
Sized, TypeVar, Union)
|
| 6 |
+
|
| 7 |
+
import cv2
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from torch.utils.data import BatchSampler, Dataset, Sampler
|
| 12 |
+
|
| 13 |
+
# Bucketed (height, width) resolutions keyed by their aspect ratio, all sized
# for a 512-base training resolution.
ASPECT_RATIO_512 = {
    '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
    '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
    '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
    '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
    '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
    '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
    '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
    '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
    '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
    '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
}
# Reduced ratio set used when training with random crops.
ASPECT_RATIO_RANDOM_CROP_512 = {
    '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
    '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
    '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
    '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
    '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
}
# Per-ratio sampling weights (aligned with ASPECT_RATIO_RANDOM_CROP_512 order),
# normalized below into a probability distribution.
ASPECT_RATIO_RANDOM_CROP_PROB = [
    1, 2,
    4, 4, 4, 4,
    8, 8, 8,
    4, 4, 4, 4,
    2, 1
]
ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)

def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
    """Return (bucket ``[h, w]``, ratio-as-float) nearest to ``height/width``."""
    target = height / width
    best_key = min(ratios, key=lambda key: abs(float(key) - target))
    return ratios[best_key], float(best_key)
|
| 45 |
+
|
| 46 |
+
def get_image_size_without_loading(path):
    # PIL's open is lazy: it parses the header only, so the pixel data is
    # never decoded — cheap enough to call per-sample while bucketing.
    with Image.open(path) as img:
        return img.size  # (width, height)
|
| 49 |
+
|
| 50 |
+
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.

    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """

    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator
        # Resume cursor into the current permutation; lets a restarted epoch
        # continue from where the previous iteration stopped.
        self._pos_start = 0

        if not isinstance(self.replacement, bool):
            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            # No generator supplied: seed a fresh one from torch's RNG so
            # each epoch gets an independent shuffle.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            # Draw in chunks of 32 to amortize the randint calls.
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            for _ in range(self.num_samples // n):
                xx = torch.randperm(n, generator=generator).tolist()
                if self._pos_start >= n:
                    self._pos_start = 0
                # NOTE(review): debug print left in by the original author;
                # it fires once per permutation.
                print("xx top 10", xx[:10], self._pos_start)
                # Resume mid-permutation: the cursor advances with each yield
                # so an interrupted epoch picks up where it left off.
                for idx in range(self._pos_start, n):
                    yield xx[idx]
                    self._pos_start = (self._pos_start + 1) % n
                self._pos_start = 0
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples
|
| 113 |
+
|
| 114 |
+
class AspectRatioBatchImageSampler(BatchSampler):
    """A sampler wrapper for grouping images with similar aspect ratio into a same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        aspect_ratios (dict): The predefined aspect ratios.
    """
    def __init__(
        self,
        sampler: Sampler,
        dataset: Dataset,
        batch_size: int,
        train_folder: str = None,
        aspect_ratios: dict = ASPECT_RATIO_512,
        drop_last: bool = False,
        config=None,
        **kwargs
    ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        # Optional root folder; dataset file paths are joined onto it.
        self.train_folder = train_folder
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        # NOTE(review): drop_last is stored but the incomplete trailing
        # buckets are discarded unconditionally in __iter__.
        self.drop_last = drop_last
        self.config = config
        # buckets for each aspect ratio
        self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
        # [str(k) for k, v in aspect_ratios]
        self.current_available_bucket_keys = list(aspect_ratios.keys())

    def __iter__(self):
        for idx in self.sampler:
            try:
                image_dict = self.dataset[idx]

                width, height = image_dict.get("width", None), image_dict.get("height", None)
                if width is None or height is None:
                    # Dimensions not in metadata: read them from the image
                    # header on disk (no pixel decode).
                    image_id, name = image_dict['file_path'], image_dict['text']
                    if self.train_folder is None:
                        image_dir = image_id
                    else:
                        image_dir = os.path.join(self.train_folder, image_id)

                    width, height = get_image_size_without_loading(image_dir)

                    ratio = height / width # self.dataset[idx]
                else:
                    height = int(height)
                    width = int(width)
                    ratio = height / width # self.dataset[idx]
            except Exception as e:
                # Unreadable/malformed sample: report and skip it.
                print(e)
                continue
            # find the closest aspect ratio
            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            bucket = self._aspect_ratio_buckets[closest_ratio]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]
|
| 187 |
+
|
| 188 |
+
class AspectRatioBatchSampler(BatchSampler):
    """A sampler wrapper for grouping videos with similar aspect ratio into a same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        video_folder (str): Optional root folder that dataset file paths are
            joined onto.
        train_data_format (str): ``"normal"`` (``file_path``/``text`` keys) or
            ``"webvid"`` (``videoid``/``name``/``page_dir`` keys).
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        aspect_ratios (dict): The predefined aspect ratios.
    """
    def __init__(
        self,
        sampler: Sampler,
        dataset: Dataset,
        batch_size: int,
        video_folder: str = None,
        train_data_format: str = "webvid",
        aspect_ratios: dict = ASPECT_RATIO_512,
        drop_last: bool = False,
        config=None,
        **kwargs
    ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.video_folder = video_folder
        self.train_data_format = train_data_format
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last
        self.config = config
        # buckets for each aspect ratio
        self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
        self.current_available_bucket_keys = list(aspect_ratios.keys())

    def __iter__(self):
        for idx in self.sampler:
            try:
                video_dict = self.dataset[idx]
                # FIX: the original unpacked into `width, more`, so `height`
                # was never bound; when the dataset supplied a width, the
                # `height is None` test below raised NameError, the broad
                # except swallowed it, and every such item was skipped.
                width, height = video_dict.get("width", None), video_dict.get("height", None)

                if width is None or height is None:
                    # Dimensions not in metadata: probe the video file itself.
                    if self.train_data_format == "normal":
                        video_id, name = video_dict['file_path'], video_dict['text']
                        if self.video_folder is None:
                            video_dir = video_id
                        else:
                            video_dir = os.path.join(self.video_folder, video_id)
                    else:
                        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
                        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
                    cap = cv2.VideoCapture(video_dir)

                    # Read the frame dimensions from the container metadata
                    # (floats from OpenCV, converted to ints).
                    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    # Release the capture handle promptly instead of leaking
                    # it until garbage collection.
                    cap.release()

                    ratio = height / width
                else:
                    height = int(height)
                    width = int(width)
                    ratio = height / width
            except Exception as e:
                print(e, self.dataset[idx], "This item is error, please check it.")
                continue
            # find the closest aspect ratio
            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            bucket = self._aspect_ratio_buckets[closest_ratio]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]
|
| 270 |
+
|
| 271 |
+
class AspectRatioBatchImageVideoSampler(BatchSampler):
    """A sampler wrapper for grouping images/videos with similar aspect ratio into a same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        train_folder (str): Optional root folder prepended to relative file paths.
        aspect_ratios (dict): The predefined aspect ratios.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 train_folder: str = None,
                 aspect_ratios: dict = ASPECT_RATIO_512,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.train_folder = train_folder
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last

        # One bucket per aspect ratio, kept separately for images and videos so
        # that every yielded batch is homogeneous in both ratio and media type.
        self.current_available_bucket_keys = list(aspect_ratios.keys())
        self.bucket = {
            'image': {ratio: [] for ratio in aspect_ratios},
            'video': {ratio: [] for ratio in aspect_ratios}
        }

    def _resolve_path(self, file_path):
        # Relative annotation paths are resolved against the training root folder.
        if self.train_folder is None:
            return file_path
        return os.path.join(self.train_folder, file_path)

    @staticmethod
    def _probe_video_size(video_path):
        """Return (width, height) of a video file, always releasing the capture handle.

        The original implementation leaked ``cv2.VideoCapture`` handles; the
        ``finally`` block guarantees the decoder/file handle is released even
        when a property read raises.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            cap.release()
        return width, height

    def _item_size(self, item):
        """Return (width, height) for one dataset entry, probing files when the
        annotation does not carry the size."""
        width, height = item.get("width", None), item.get("height", None)
        if width is not None and height is not None:
            return int(width), int(height)

        if item.get('type', 'image') == 'image':
            return get_image_size_without_loading(self._resolve_path(item['file_path']))

        if item['type'] == 'video_mask_tuple':
            video_dir = item['file_path']
            frame_dir = os.path.join(video_dir, 'input')
            if os.path.isdir(frame_dir):
                # Frame-folder layout: size of any PNG frame is the video size.
                sample_path = list(glob.glob(os.path.join(frame_dir, '*.png')))[0]
                return get_image_size_without_loading(sample_path)
            return self._probe_video_size(os.path.join(video_dir, 'rgb_full.mp4'))

        # Plain video entry.
        return self._probe_video_size(self._resolve_path(item['file_path']))

    def __iter__(self):
        for idx in self.sampler:
            item = self.dataset[idx]
            content_type = item.get('type', 'image')
            bucket_group = 'image' if content_type == 'image' else 'video'
            try:
                width, height = self._item_size(item)
                ratio = height / width
            except Exception as e:
                print(e, self.dataset[idx], "This item is error, please check it.")
                continue
            # find the closest aspect ratio
            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            bucket = self.bucket[bucket_group][closest_ratio]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]
|
videox_fun/data/dataset_image.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torchvision.transforms as transforms
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from torch.utils.data.dataset import Dataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CC15M(Dataset):
    """Image-caption dataset backed by a JSON annotation file (CC15M style).

    Each annotation entry carries ``file_path`` and ``text``. When bucketing is
    disabled, images are resized/cropped/normalized to ``resolution``; when
    enabled, the raw RGB array is returned for downstream bucket processing.
    """

    def __init__(
        self,
        json_path,
        video_folder=None,
        resolution=512,
        enable_bucket=False,
    ):
        print(f"loading annotations from {json_path} ...")
        with open(json_path, 'r') as ann_file:
            self.dataset = json.load(ann_file)
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.enable_bucket = enable_bucket
        self.video_folder = video_folder

        # Accept either a single int (square) or an explicit (h, w) pair.
        if isinstance(resolution, int):
            resolution = (resolution, resolution)
        else:
            resolution = tuple(resolution)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(resolution[0]),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        """Load one (PIL image, caption) pair by dataset index."""
        entry = self.dataset[idx]
        rel_path, caption = entry['file_path'], entry['text']

        full_path = rel_path if self.video_folder is None else os.path.join(self.video_folder, rel_path)

        image = Image.open(full_path).convert("RGB")
        return image, caption

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Retry with a random index on any load failure so training never stalls.
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception as e:
                print(e)
                idx = random.randint(0, self.length - 1)

        if self.enable_bucket:
            pixel_values = np.array(pixel_values)
        else:
            pixel_values = self.pixel_transforms(pixel_values)

        return dict(pixel_values=pixel_values, text=name)
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
    # Smoke test: iterate a few batches and print tensor shapes.
    # NOTE: the annotation path is machine-specific; adjust before running.
    dataset = CC15M(
        # Fix: the constructor parameter is ``json_path`` — the original passed
        # ``csv_path=`` which raised TypeError immediately.
        json_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
        resolution=512,
    )

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))
|
videox_fun/data/dataset_image_video.py
ADDED
|
@@ -0,0 +1,1067 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import io
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import glob
|
| 7 |
+
import random
|
| 8 |
+
from threading import Thread
|
| 9 |
+
import mediapy as media
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
import albumentations
|
| 13 |
+
import cv2
|
| 14 |
+
import gc
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
import torchvision.transforms as transforms
|
| 18 |
+
from scipy.special import binom
|
| 19 |
+
|
| 20 |
+
from func_timeout import func_timeout, FunctionTimedOut
|
| 21 |
+
from decord import VideoReader
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from torch.utils.data import BatchSampler, Sampler
|
| 24 |
+
from torch.utils.data.dataset import Dataset
|
| 25 |
+
from contextlib import contextmanager
|
| 26 |
+
|
| 27 |
+
VIDEO_READER_TIMEOUT = 20
|
| 28 |
+
|
| 29 |
+
def bernstein(n, k, t):
    """Bernstein basis polynomial b_{k,n} evaluated at t (scalar or array).

    Converted from a named lambda (PEP 8 E731) to a plain function; behavior
    is unchanged.
    """
    return binom(n, k) * t**k * (1. - t)**(n - k)


# codes from https://stackoverflow.com/questions/50731785/create-random-shape-contour-using-matplotlib
def bezier(points, num=200):
    """Evaluate the Bezier curve defined by *points* at *num* evenly spaced t in [0, 1].

    Args:
        points: sequence of (x, y) control points.
        num: number of samples along the curve.

    Returns:
        (num, 2) float array of curve coordinates.
    """
    n_ctrl = len(points)
    t = np.linspace(0, 1, num=num)
    curve = np.zeros((num, 2))
    for i in range(n_ctrl):
        # Each control point contributes weighted by its Bernstein basis.
        curve += np.outer(bernstein(n_ctrl - 1, i, t), points[i])
    return curve
|
| 39 |
+
|
| 40 |
+
class Segment():
    """One cubic Bezier segment between two anchor points with given tangent angles."""

    def __init__(self, p1, p2, angle1, angle2, **kw):
        self.p1 = p1
        self.p2 = p2
        self.angle1 = angle1
        self.angle2 = angle2
        self.numpoints = kw.get("numpoints", 100)
        # Control-point distance is a fraction r of the anchor separation.
        radius_frac = kw.get("r", 0.3)
        anchor_dist = np.sqrt(np.sum((self.p2 - self.p1) ** 2))
        self.r = radius_frac * anchor_dist
        self.p = np.zeros((4, 2))
        self.p[0, :] = self.p1[:]
        self.p[3, :] = self.p2[:]
        self.calc_intermediate_points(self.r)

    def calc_intermediate_points(self, r):
        """Place the two interior control points along the tangent directions
        and evaluate the segment's curve.

        NOTE(review): the *r* argument is ignored; ``self.r`` is used instead —
        preserved as-is to keep behavior identical.
        """
        self.p[1, :] = self.p1 + np.array(
            [self.r * np.cos(self.angle1), self.r * np.sin(self.angle1)])
        self.p[2, :] = self.p2 + np.array(
            [self.r * np.cos(self.angle2 + np.pi), self.r * np.sin(self.angle2 + np.pi)])
        self.curve = bezier(self.p, self.numpoints)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_curve(points, **kw):
    """Build consecutive Segments through *points* and concatenate their curves.

    Args:
        points: (n, 3) array of rows (x, y, tangent_angle).
        **kw: forwarded to ``Segment`` (e.g. r, numpoints).

    Returns:
        (segments, curve) where curve stacks every segment's sampled points.
    """
    segments = [
        Segment(points[i, :2], points[i + 1, :2], points[i, 2], points[i + 1, 2], **kw)
        for i in range(len(points) - 1)
    ]
    curve = np.concatenate([s.curve for s in segments])
    return segments, curve
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def ccw_sort(p):
    """Sort 2D points by angle around their centroid (counter-clockwise order)."""
    centered = p - np.mean(p, axis=0)
    # Angle convention matches the original: arctan2(x, y), i.e. measured from +y.
    angles = np.arctan2(centered[:, 0], centered[:, 1])
    return p[np.argsort(angles), :]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_bezier_curve(a, rad=0.2, edgy=0):
    """ given an array of points *a*, create a curve through
    those points.
    *rad* is a number between 0 and 1 to steer the distance of
    control points.
    *edgy* is a parameter which controls how "edgy" the curve is,
    edgy=0 is smoothest."""
    # Blend factor between the two candidate tangent angles at each point;
    # edgy=0 -> p=0.5 (smooth average), large edgy -> p -> 1.
    p = np.arctan(edgy)/np.pi+.5
    # Order points counter-clockwise and close the polygon by repeating point 0.
    a = ccw_sort(a)
    a = np.append(a, np.atleast_2d(a[0,:]), axis=0)
    d = np.diff(a, axis=0)
    # Heading of each polygon edge, normalized into [0, 2*pi).
    ang = np.arctan2(d[:,1],d[:,0])
    f = lambda ang : (ang>=0)*ang + (ang<0)*(ang+2*np.pi)
    ang = f(ang)
    ang1 = ang
    ang2 = np.roll(ang,1)
    # Tangent at each vertex: weighted mix of incoming/outgoing edge headings,
    # with a pi correction when the two headings straddle the branch cut.
    ang = p*ang1 + (1-p)*ang2 + (np.abs(ang2-ang1) > np.pi )*np.pi
    ang = np.append(ang, [ang[0]])
    # Append the tangent angle as a third column: rows become (x, y, angle).
    a = np.append(a, np.atleast_2d(ang).T, axis=1)
    s, c = get_curve(a, r=rad, method="var")
    x,y = c.T
    return x,y, a
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def get_random_points(n=5, scale=0.8, mindst=None, rec=0):
    """ create n random points in the unit square, which are *mindst*
    apart, then scale them."""
    # Default minimum separation shrinks as the number of points grows.
    mindst = mindst or .7/n
    a = np.random.rand(n,2)
    # NOTE(review): this sums the per-axis diffs before squaring, so d is
    # |dx + dy| rather than the Euclidean distance between consecutive points.
    # It matches the original StackOverflow snippet; confirm intent before
    # "fixing", since changing it alters the generated point distribution.
    d = np.sqrt(np.sum(np.diff(ccw_sort(a), axis=0), axis=1)**2)
    # Accept when all consecutive gaps are large enough, or after 200 retries.
    if np.all(d >= mindst) or rec>=200:
        return a*scale
    else:
        return get_random_points(n=n, scale=scale, mindst=mindst, rec=rec+1)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def fill_mask(shape, x, y, fill_val=255):
    """Rasterize the closed polygon given by coordinate arrays (x, y) into a
    (h, w) uint8 mask filled with *fill_val* inside the polygon."""
    _, _, h, w = shape
    polygon = np.array([x, y], np.int32).T
    canvas = np.zeros((h, w), dtype=np.uint8)
    return cv2.fillPoly(canvas, [polygon], fill_val)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def random_shift(x, y, scale_range = [0.2, 0.7], trans_perturb_range=[-0.2, 0.2]):
|
| 122 |
+
w_scale = np.random.uniform(scale_range[0], scale_range[1])
|
| 123 |
+
h_scale = np.random.uniform(scale_range[0], scale_range[1])
|
| 124 |
+
x_trans = np.random.uniform(0., 1. - w_scale)
|
| 125 |
+
y_trans = np.random.uniform(0., 1. - h_scale)
|
| 126 |
+
x_shifted = x * w_scale + x_trans + np.random.uniform(trans_perturb_range[0], trans_perturb_range[1])
|
| 127 |
+
y_shifted = y * h_scale + y_trans + np.random.uniform(trans_perturb_range[0], trans_perturb_range[1])
|
| 128 |
+
return x_shifted, y_shifted
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def get_random_shape_mask(
    shape, n_pts_range=[3, 10], rad_range=[0.0, 1.0], edgy_range=[0.0, 0.1], n_keyframes_range=[2, 25],
    random_drop_range=[0.0, 0.2],
):
    """Generate an animated random-blob mask: a random Bezier shape is drawn at
    several keyframes, linearly interpolated between them, rasterized per frame,
    and a random contiguous span of frames is zeroed out.

    Returns a (f, h, w) float array with values in [0, 1].
    """
    f, _, h, w = shape

    n_pts = np.random.randint(n_pts_range[0], n_pts_range[1])
    n_keyframes = np.random.randint(n_keyframes_range[0], n_keyframes_range[1])
    # Evenly spaced keyframes; force the last keyframe onto the final frame.
    keyframe_interval = f // (n_keyframes - 1)
    keyframe_indices = list(range(0, f, keyframe_interval))
    if len(keyframe_indices) == n_keyframes:
        keyframe_indices[-1] = f - 1
    else:
        keyframe_indices.append(f - 1)
    x_all_frames, y_all_frames = [], []
    for i, keyframe_index in enumerate(keyframe_indices):
        rad = np.random.uniform(rad_range[0], rad_range[1])
        edgy = np.random.uniform(edgy_range[0], edgy_range[1])
        # Fresh random shape per keyframe (same n_pts -> same curve length).
        x_kf, y_kf, _ = get_bezier_curve(get_random_points(n=n_pts), rad=rad, edgy=edgy)
        x_kf, y_kf = random_shift(x_kf, y_kf)
        if i == 0:
            x_all_frames.append(x_kf[None])
            y_all_frames.append(y_kf[None])
        else:
            # Linearly interpolate contour coordinates between keyframes;
            # drop the first interpolant (it duplicates the previous keyframe).
            x_interval = np.linspace(x_all_frames[-1][-1], x_kf, keyframe_index - keyframe_indices[i - 1] + 1)
            y_interval = np.linspace(y_all_frames[-1][-1], y_kf, keyframe_index - keyframe_indices[i - 1] + 1)
            x_all_frames.append(x_interval[1:])
            y_all_frames.append(y_interval[1:])
    x_all_frames = np.concatenate(x_all_frames, axis=0)
    y_all_frames = np.concatenate(y_all_frames, axis=0)

    # Rasterize each frame's contour into a binary mask, then normalize to [0, 1].
    masks = []
    for x, y in zip(x_all_frames, y_all_frames):
        x = np.round(x * w).astype(np.int32)
        y = np.round(y * h).astype(np.int32)
        mask = fill_mask(shape, x, y)
        masks.append(mask)
    masks = np.stack(masks, axis=0).astype(float) / 255.

    # Zero out a random contiguous span of frames (up to random_drop_range of f).
    n_frames_random_drop = int(np.random.uniform(random_drop_range[0], random_drop_range[1]) * f)
    drop_index = np.random.randint(0, f - n_frames_random_drop)
    masks[drop_index:drop_index + n_frames_random_drop] = 0

    return masks # (f, h, w), <float>[0, 1]
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def get_random_mask(shape, mask_type_probs=[0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.8]):
    """Sample one random inpainting mask of shape (f, 1, h, w), values {0, 1} float.

    shape is (frames, channels, height, width). Eleven mask families are drawn
    per *mask_type_probs* for videos (f != 1); single images use only types 0/1.
    """
    f, c, h, w = shape

    if f != 1:
        mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], p=mask_type_probs)
    else:
        # Single image: only "static rectangle" (0) or "mask everything" (1).
        mask_index = np.random.choice([0, 1], p = [0.2, 0.8])
    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)

    if mask_index == 0:
        # Static rectangle, same on every frame; clamped to the image bounds.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()
        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()

        start_x = max(center_x - block_size_x // 2, 0)
        end_x = min(center_x + block_size_x // 2, w)
        start_y = max(center_y - block_size_y // 2, 0)
        end_y = min(center_y + block_size_y // 2, h)
        mask[:, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 1:
        # Whole clip masked (pure generation).
        mask[:, :, :, :] = 1
    elif mask_index == 2:
        # Keep the first 1-4 frames, mask the rest (image-to-video style).
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:, :, :, :] = 1
    elif mask_index == 3:
        # Keep both ends, mask the middle (frame interpolation style).
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
    elif mask_index == 4:
        # Rectangle restricted to a random temporal window.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()
        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()

        start_x = max(center_x - block_size_x // 2, 0)
        end_x = min(center_x + block_size_x // 2, w)
        start_y = max(center_y - block_size_y // 2, 0)
        end_y = min(center_y + block_size_y // 2, h)

        mask_frame_before = np.random.randint(0, f // 2)
        mask_frame_after = np.random.randint(f // 2, f)
        mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 5:
        # Per-pixel random noise mask.
        mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
    elif mask_index == 6:
        # Small random rectangles on a random subset of frames.
        num_frames_to_mask = random.randint(1, max(f // 2, 1))
        frames_to_mask = random.sample(range(f), num_frames_to_mask)

        for i in frames_to_mask:
            block_height = random.randint(1, h // 4)
            block_width = random.randint(1, w // 4)
            top_left_y = random.randint(0, h - block_height)
            top_left_x = random.randint(0, w - block_width)
            mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
    elif mask_index == 7:
        # Static ellipse (O(h*w) Python loop; fine for typical resolutions).
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item()
        b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()

        for i in range(h):
            for j in range(w):
                if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
                    mask[:, :, i, j] = 1
    elif mask_index == 8:
        # Static circle.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
        for i in range(h):
            for j in range(w):
                if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
                    mask[:, :, i, j] = 1
    elif mask_index == 9:
        # Each frame independently fully masked with probability 0.5.
        for idx in range(f):
            if np.random.rand() > 0.5:
                mask[idx, :, :, :] = 1
    else:
        # Union of 1-3 animated random blob masks.
        num_objs = np.random.randint(1, 4)
        mask_npy = get_random_shape_mask(shape)
        for i in range(num_objs - 1):
            # NOTE(review): clip is applied to the added mask, not the running
            # sum, so mask_npy can exceed 1 here; downstream .float() keeps it.
            mask_npy += get_random_shape_mask(shape).clip(0, 1)

        mask = torch.from_numpy(mask_npy).unsqueeze(1)

    return mask.float()
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def get_random_mask_multi(shape, mask_type_probs, range_num_masks=[1, 7]):
    """Union of several random masks; the count is drawn uniformly from
    [range_num_masks[0], range_num_masks[1]) -- note the exclusive upper bound."""
    num_masks = np.random.randint(range_num_masks[0], range_num_masks[1])
    combined = get_random_mask(shape, mask_type_probs)
    for _ in range(num_masks - 1):
        # Accumulate and clamp so overlapping masks stay binary.
        combined = (combined + get_random_mask(shape, mask_type_probs)).clip(0, 1)
    return combined
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class ImageVideoSampler(BatchSampler):
    """Batch sampler that groups dataset indices by content type.

    Indices are accumulated in one bucket per type ('image', 'video',
    'video_mask_tuple'); a batch is yielded as soon as any bucket fills, so
    every batch is homogeneous in content type.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # One bucket of pending indices per content type.
        self.bucket = {'image': [], 'video': [], 'video_mask_tuple': []}

    def __iter__(self):
        for idx in self.sampler:
            # NOTE: type is read from the wrapped dataset's annotation list
            # (``dataset.dataset``), defaulting to 'image'.
            entry_type = self.dataset.dataset[idx].get('type', 'image')
            self.bucket[entry_type].append(idx)

            # Yield the first full bucket; video types are checked before images,
            # matching the original elif-chain priority.
            for key in ('video', 'video_mask_tuple', 'image'):
                pending = self.bucket[key]
                if len(pending) == self.batch_size:
                    yield pending[:]
                    del pending[:]
                    break
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
@contextmanager
def VideoReader_contextmanager(*args, **kwargs):
    """Context manager yielding a decord VideoReader; forces cleanup on exit."""
    video_reader = VideoReader(*args, **kwargs)
    try:
        yield video_reader
    finally:
        # Drop the reader and collect immediately so decoder buffers are freed
        # even if the body raised.
        del video_reader
        gc.collect()
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def get_video_reader_batch(video_reader, batch_index):
    """Fetch the frames at *batch_index* from a decord reader as a numpy array."""
    return video_reader.get_batch(batch_index).asnumpy()
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def _read_video_from_dir(video_dir):
    """Load every PNG frame in *video_dir* (sorted by filename) into one
    (frames, h, w, c) array.

    Raises:
        ValueError: if the directory contains no PNG files or none could be read.
    """
    frame_paths = sorted(glob.glob(os.path.join(video_dir, '*.png')))

    if not frame_paths:
        raise ValueError(f"No PNG files found in directory: {video_dir}")

    frames = [media.read_image(frame_path) for frame_path in frame_paths]

    if not frames:
        raise ValueError(f"Failed to read any frames from directory: {video_dir}")

    return np.stack(frames, axis=0)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def resize_frame(frame, target_short_side):
    """Downscale *frame* so its shorter side equals target_short_side, keeping
    the aspect ratio.

    Frames whose short side is already smaller than the target are returned
    unchanged (no upscaling).
    """
    h, w, _ = frame.shape
    if target_short_side > min(h, w):
        # Never upscale: pass the frame through untouched.
        return frame
    if h < w:
        new_h = target_short_side
        new_w = int(target_short_side * w / h)
    else:
        new_w = target_short_side
        new_h = int(target_short_side * h / w)
    return cv2.resize(frame, (new_w, new_h))
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
class ImageVideoDataset(Dataset):
    """Mixed image / video dataset for removal-style training.

    Each annotation record (from a .csv or .json file) carries a ``type``
    field:

    * ``image``            - single image plus caption.
    * ``video``            - video clip plus caption; when ``mask_path`` is
                             set, a sibling ``*_mask.mp4`` is loaded too.
    * ``video_mask_tuple`` - sample directory with an input clip, the
                             background (object removed) clip, a tri/quad
                             mask clip and, optionally, a depth clip.

    With ``enable_bucket`` samples stay as uint8 numpy arrays in [0, 255]
    (the collate step normalizes); otherwise they are returned as float
    tensors normalized to [-1, 1].
    """

    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        enable_inpaint=False,
        trimask_zeroout_removal=False,
        use_quadmask=False,
        ablation_binary_mask=False,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            # Fix: use a context manager so the annotation file handle is
            # not leaked (the original called json.load(open(ann_path))).
            with open(ann_path, 'r') as jsonfile:
                dataset = json.load(jsonfile)
        else:
            raise ValueError(f"Unsupported annotation file format: {ann_path}. Only .csv and .json files are supported.")

        self.data_root = data_root

        # It's used to balance num of images and videos: non-video records
        # are kept once, video records are duplicated `video_repeat` times
        # (and dropped entirely when video_repeat == 0).
        self.dataset = []
        for data in dataset:
            if data.get('type', 'image') != 'video':
                self.dataset.append(data)
        if video_repeat > 0:
            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.trimask_zeroout_removal = trimask_zeroout_removal
        self.use_quadmask = use_quadmask
        self.ablation_binary_mask = ablation_binary_mask

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        if self.use_quadmask:
            print("[QUADMASK MODE] Using 4-value quadmask: [0, 63, 127, 255]")
            if self.ablation_binary_mask:
                print("[ABLATION BINARY MASK] Remapping quadmask to binary: [0,63]→0, [127,255]→127")
        else:
            print("[TRIMASK MODE] Using 3-value trimask: [0, 127, 255]")

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        # Decode-time resize target: keep enough resolution for whichever
        # branch (image or video) needs the larger crop.
        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        """Load one sample; the returned dict's keys depend on its type.

        Raises ValueError on any decode failure so ``__getitem__`` can
        retry with a different random index.
        """
        data_info = self.dataset[idx % len(self.dataset)]

        if data_info.get('type', 'image') == 'video' and data_info.get('mask_path', None) is None:
            video_id, text = data_info['file_path'], data_info['text']

            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError("No Frames in video.")

                # Pick a random clip of evenly spaced frames inside the
                # [drop_start, drop_end] portion of the video.
                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader
                # else: bucket mode keeps the raw uint8 numpy frames.

                if not self.enable_bucket:
                    pixel_values = self.video_transforms(pixel_values)

                # Random use no text generation
                if random.random() < self.text_drop_ratio:
                    text = ''
                return {
                    'pixel_values': pixel_values,
                    'text': text,
                    'data_type': 'video',
                }

        elif data_info.get('type', 'image') == 'video' and data_info.get('mask_path', None) is not None:  # video with known mask
            # NOTE(review): mask_path only gates this branch; the actual mask
            # file is derived from the video path — confirm intended.
            video_path, text = data_info['file_path'], data_info['text']
            mask_video_path = video_path[:-4] + '_mask.mp4'
            with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError("No Frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    input_video = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            # The mask video is sampled with the same batch_index so mask and
            # RGB frames stay temporally aligned.
            with VideoReader_contextmanager(mask_video_path, num_threads=2) as video_reader:
                try:
                    sample_args = (video_reader, batch_index)
                    mask_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(mask_values)):
                        frame = mask_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    mask_video = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            # Normalize mask layout to (F, H, W, 1).
            if len(mask_video.shape) == 3:
                mask_video = mask_video[..., None]
            if mask_video.shape[-1] == 3:
                mask_video = mask_video[..., :1]
            if len(mask_video.shape) != 4:
                raise ValueError(f"mask_video shape is {mask_video.shape}.")

            text = data_info['text']
            if not self.enable_bucket:
                input_video = torch.from_numpy(input_video).permute(0, 3, 1, 2).contiguous() / 255.
                mask_video = torch.from_numpy(mask_video).permute(0, 3, 1, 2).contiguous() / 255.

                # Concatenate on channels so RGB and mask receive the exact
                # same random crop, then split again.
                pixel_values = torch.cat([input_video, mask_video], dim=1)
                pixel_values = self.video_transforms(pixel_values)
                input_video = pixel_values[:, :3]
                mask_video = pixel_values[:, 3:]

            # Random use no text generation
            if random.random() < self.text_drop_ratio:
                text = ''

            return {
                'pixel_values': input_video,
                'mask': mask_video,
                'text': text,
                'data_type': 'video',
            }

        elif data_info.get('type', 'image') == 'video_mask_tuple':  # object effect removal
            sample_dir = data_info['file_path']
            try:
                if os.path.exists(os.path.join(sample_dir, 'rgb_full.mp4')):
                    # mp4 layout: full clip, removed clip, mask, optional depth.
                    input_video_path = os.path.join(sample_dir, 'rgb_full.mp4')
                    target_video_path = os.path.join(sample_dir, 'rgb_removed.mp4')
                    mask_video_path = os.path.join(sample_dir, 'mask.mp4')
                    depth_video_path = os.path.join(sample_dir, 'depth_removed.mp4')

                    input_video = media.read_video(input_video_path)
                    target_video = media.read_video(target_video_path)
                    mask_video = media.read_video(mask_video_path)

                    # Load depth map if it exists
                    depth_video = None
                    if os.path.exists(depth_video_path):
                        depth_video = media.read_video(depth_video_path)
                else:
                    # PNG-directory layout (no depth maps available).
                    input_video_path = os.path.join(sample_dir, 'input')
                    target_video_path = os.path.join(sample_dir, 'bg')
                    mask_video_path = os.path.join(sample_dir, 'trimask')

                    input_video = _read_video_from_dir(input_video_path)
                    target_video = _read_video_from_dir(target_video_path)
                    mask_video = _read_video_from_dir(mask_video_path)

                    depth_video = None
            except Exception as e:
                print(f"Error loading video_mask_tuple from {sample_dir}: {e}")
                import traceback
                traceback.print_exc()
                raise

            mask_video = 255 - mask_video  # will be flipped again when feeding to model

            if len(mask_video.shape) == 3:
                mask_video = mask_video[..., None]
            if mask_video.shape[-1] == 3:
                mask_video = mask_video[..., :1]
            min_sample_n_frames = min(
                self.video_sample_n_frames,
                int(len(input_video) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
            )
            # Consistency fix: the other video branches guard against empty
            # clips; without this, clip_length below could go negative.
            if min_sample_n_frames == 0:
                raise ValueError("No Frames in video.")
            video_length = int(self.video_length_drop_end * len(input_video))
            clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
            start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
            input_video = input_video[batch_index]
            target_video = target_video[batch_index]
            mask_video = mask_video[batch_index]
            if depth_video is not None:
                depth_video = depth_video[batch_index]

            resized_inputs = []
            resized_targets = []
            resized_masks = []
            resized_depths = []
            for i in range(len(input_video)):
                resized_input = resize_frame(input_video[i], self.larger_side_of_image_and_video)
                resized_target = resize_frame(target_video[i], self.larger_side_of_image_and_video)
                resized_mask = resize_frame(mask_video[i], self.larger_side_of_image_and_video)

                # Re-quantize the mask after interpolation blurred its levels.
                if self.ablation_binary_mask:
                    # Binary ablation: [0, 63] -> 0, [127, 255] -> 127.
                    resized_mask = np.where(resized_mask <= 95, 0, resized_mask)
                    resized_mask = np.where(resized_mask > 95, 127, resized_mask)
                elif self.use_quadmask:
                    # Quadmask: snap to the nearest of [0, 63, 127, 255].
                    resized_mask = np.where(resized_mask <= 31, 0, resized_mask)
                    resized_mask = np.where(np.logical_and(resized_mask > 31, resized_mask <= 95), 63, resized_mask)
                    resized_mask = np.where(np.logical_and(resized_mask > 95, resized_mask <= 191), 127, resized_mask)
                    resized_mask = np.where(resized_mask > 191, 255, resized_mask)
                else:
                    # Trimask: snap to the nearest of [0, 127, 255].
                    resized_mask = np.where(np.logical_and(resized_mask > 63, resized_mask < 192), 127, resized_mask)
                    resized_mask = np.where(resized_mask >= 192, 255, resized_mask)
                    resized_mask = np.where(resized_mask <= 63, 0, resized_mask)

                resized_inputs.append(resized_input)
                resized_targets.append(resized_target)
                resized_masks.append(resized_mask)

                if depth_video is not None:
                    resized_depth = resize_frame(depth_video[i], self.larger_side_of_image_and_video)
                    resized_depths.append(resized_depth)

            input_video = np.array(resized_inputs)
            target_video = np.array(resized_targets)
            mask_video = np.array(resized_masks)
            if depth_video is not None:
                depth_video = np.array(resized_depths)

            if len(mask_video.shape) == 3:
                mask_video = mask_video[..., None]
            if mask_video.shape[-1] == 3:
                mask_video = mask_video[..., :1]
            if len(mask_video.shape) != 4:
                raise ValueError(f"mask_video shape is {mask_video.shape}.")

            text = data_info['text']
            if not self.enable_bucket:
                input_video = torch.from_numpy(input_video).permute(0, 3, 1, 2).contiguous() / 255.
                target_video = torch.from_numpy(target_video).permute(0, 3, 1, 2).contiguous() / 255.
                mask_video = torch.from_numpy(mask_video).permute(0, 3, 1, 2).contiguous() / 255.

                if depth_video is not None:
                    # Copy so the tensor never aliases a memory-mapped file
                    # (mapped pages can fault during GPU transfer).
                    depth_video = np.array(depth_video, copy=True)
                    if len(depth_video.shape) == 3:
                        depth_video = depth_video[..., None]
                    if depth_video.shape[-1] == 3:
                        # Collapse RGB depth renderings to a single channel.
                        depth_video = depth_video.mean(axis=-1, keepdims=True)
                    depth_video = torch.from_numpy(depth_video).permute(0, 3, 1, 2).contiguous().float() / 255.
                    depth_video = depth_video.clone().contiguous()

                # Transforms expect 3 channels, so the single-channel mask is
                # rescaled to [-1, 1] by hand instead.
                input_video = self.video_transforms(input_video)
                target_video = self.video_transforms(target_video)
                mask_video = mask_video * 2.0 - 1.0
            # else: bucket mode keeps uint8 numpy arrays; collate normalizes.

            # Random use no text generation
            if random.random() < self.text_drop_ratio:
                text = ''

            if self.trimask_zeroout_removal:
                # Zero out pixels the mask marks for removal. Fix: the
                # original applied the 0-255 threshold plus numpy .astype to
                # torch tensors already normalized to [-1, 1], which raised
                # in non-bucket mode; branch on the actual representation.
                if isinstance(mask_video, torch.Tensor):
                    # Raw 255 maps to 1.0 after the *2-1 rescale above;
                    # raw 127 maps to ~0, so > 0.5 isolates the removal area.
                    keep = (mask_video <= 0.5).to(input_video.dtype)
                    input_video = input_video * keep
                else:
                    input_video = input_video * np.where(mask_video > 200, 0, 1).astype(input_video.dtype)

            result = {
                'pixel_values': target_video,
                'input_condition': input_video,
                'mask': mask_video,
                'text': text,
                'data_type': 'video_mask_tuple',
            }

            # Add depth maps if available
            if depth_video is not None:
                result['depth_maps'] = depth_video

            return result

        else:
            # Plain image sample.
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)
            if random.random() < self.text_drop_ratio:
                text = ''
            return {
                'pixel_values': image,
                'text': text,
                'data_type': 'image',
            }

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                # Retries pick a random index; reject candidates whose type
                # differs so a batch stays type-homogeneous.
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                sample = self.get_batch(idx)
                sample["idx"] = idx

                if len(sample) > 0:
                    break
            except Exception as e:
                import traceback
                print(f"Error loading sample at index {idx}:")
                print(f"Data info: {self.dataset[idx % len(self.dataset)]}")
                print(f"Error: {e}")
                traceback.print_exc()
                idx = random.randint(0, self.length - 1)

        if self.enable_inpaint and not self.enable_bucket:
            if "mask" not in sample:
                mask = get_random_mask_multi(sample["pixel_values"].size())
                sample["mask"] = mask
            else:
                mask = sample["mask"]

            if "input_condition" in sample:
                mask_pixel_values = sample["input_condition"]
            else:
                # NOTE(review): masking is applied only when there is no
                # input_condition — confirm this matches training intent.
                mask_pixel_values = sample["pixel_values"]
                mask_pixel_values = mask_pixel_values * (1 - mask) + torch.ones_like(mask_pixel_values) * -1 * mask

            sample["mask_pixel_values"] = mask_pixel_values

            # First frame, de-normalized back to [0, 255] for the CLIP encoder.
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

            # Reference frame; blanked to -1 when the mask covers everything.
            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
            if (mask == 1).all():
                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
            sample["ref_pixel_values"] = ref_pixel_values

        return sample
class ImageVideoControlDataset(Dataset):
    """Paired (content, control) dataset for controllable generation.

    Each annotation record provides ``file_path`` (the content image/video),
    ``control_file_path`` (the control signal rendered at the same frames),
    and a ``text`` caption. Video samples decode both clips at the same
    frame indices; image samples load both stills. With ``enable_bucket``
    samples stay as uint8 numpy arrays; otherwise they are normalized
    tensors in [-1, 1].
    """

    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        enable_inpaint=False,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))
        else:
            raise ValueError(f"Unsupported annotation file format: {ann_path}. Only .csv and .json files are supported.")

        self.data_root = data_root

        # It's used to balance num of images and videos: non-video records
        # are kept once, video records are duplicated `video_repeat` times
        # (and dropped entirely when video_repeat == 0).
        self.dataset = []
        for data in dataset:
            if data.get('type', 'image') != 'video':
                self.dataset.append(data)
        if video_repeat > 0:
            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
        ])

        # Decode-time resize target shared by both image and video branches.
        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        """Load one (content, control, caption, type) tuple.

        Raises ValueError on decode failures so ``__getitem__`` can retry
        with another index.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        video_id, text = data_info['file_path'], data_info['text']

        if data_info.get('type', 'image')=='video':
            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError(f"No Frames in video.")

                # Random clip of evenly spaced frames inside the kept range.
                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader
                else:
                    pixel_values = pixel_values

                if not self.enable_bucket:
                    pixel_values = self.video_transforms(pixel_values)

                # Random use no text generation
                if random.random() < self.text_drop_ratio:
                    text = ''

            control_video_id = data_info['control_file_path']

            if self.data_root is None:
                control_video_id = control_video_id
            else:
                control_video_id = os.path.join(self.data_root, control_video_id)

            # The control clip is sampled at the same batch_index so it stays
            # frame-aligned with the content clip.
            with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
                try:
                    sample_args = (control_video_reader, batch_index)
                    control_pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(control_pixel_values)):
                        frame = control_pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    control_pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
                    control_pixel_values = control_pixel_values / 255.
                    del control_video_reader
                else:
                    control_pixel_values = control_pixel_values

                if not self.enable_bucket:
                    control_pixel_values = self.video_transforms(control_pixel_values)
                return pixel_values, control_pixel_values, text, "video"
        else:
            # Image sample: load content and control stills.
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)

            if random.random() < self.text_drop_ratio:
                text = ''

            control_image_id = data_info['control_file_path']

            if self.data_root is None:
                control_image_id = control_image_id
            else:
                control_image_id = os.path.join(self.data_root, control_image_id)

            control_image = Image.open(control_image_id).convert('RGB')
            if not self.enable_bucket:
                control_image = self.image_transforms(control_image).unsqueeze(0)
            else:
                control_image = np.expand_dims(np.array(control_image), 0)
            return image, control_image, text, 'image'

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                # Retries pick a random index; reject candidates whose type
                # differs so a batch stays type-homogeneous.
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["control_pixel_values"] = control_pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx

                if len(sample) > 0:
                    break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)

        if self.enable_inpaint and not self.enable_bucket:
            # Random inpainting mask applied to the content frames.
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            # First frame, de-normalized back to [0, 255] for the CLIP encoder.
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

            # Reference frame; blanked to -1 when the mask covers everything.
            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
            if (mask == 1).all():
                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
            sample["ref_pixel_values"] = ref_pixel_values

        return sample
videox_fun/data/dataset_image_video_warped.py
ADDED
|
@@ -0,0 +1,1092 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import io
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import glob
|
| 7 |
+
import random
|
| 8 |
+
from threading import Thread
|
| 9 |
+
import mediapy as media
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
import albumentations
|
| 13 |
+
import cv2
|
| 14 |
+
import gc
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
import torchvision.transforms as transforms
|
| 18 |
+
from scipy.special import binom
|
| 19 |
+
|
| 20 |
+
from func_timeout import func_timeout, FunctionTimedOut
|
| 21 |
+
from decord import VideoReader
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from torch.utils.data import BatchSampler, Sampler
|
| 24 |
+
from torch.utils.data.dataset import Dataset
|
| 25 |
+
from contextlib import contextmanager
|
| 26 |
+
|
| 27 |
+
VIDEO_READER_TIMEOUT = 20
|
| 28 |
+
|
| 29 |
+
def bernstein(n, k, t):
    """Bernstein basis polynomial ``B_{k,n}(t) = C(n, k) * t**k * (1-t)**(n-k)``.

    Args:
        n: polynomial degree.
        k: basis index, 0 <= k <= n.
        t: scalar or numpy array of parameter values (typically in [0, 1]).

    Returns:
        The basis value(s), same shape as *t*.
    """
    # PEP 8: a named def instead of the original `bernstein = lambda ...`.
    return binom(n, k) * t ** k * (1. - t) ** (n - k)
|
| 30 |
+
|
| 31 |
+
# codes from https://stackoverflow.com/questions/50731785/create-random-shape-contour-using-matplotlib
|
| 32 |
+
def bezier(points, num=200):
    """Sample *num* points along the Bezier curve with control points *points*.

    Args:
        points: (n, 2) array of 2D control points.
        num: number of evenly spaced parameter values in [0, 1].

    Returns:
        (num, 2) array of curve coordinates.
    """
    degree = len(points) - 1
    t = np.linspace(0, 1, num=num)
    curve = np.zeros((num, 2))
    # Accumulate the Bernstein-weighted contribution of each control point.
    for k, ctrl in enumerate(points):
        curve += np.outer(bernstein(degree, k, t), ctrl)
    return curve
| 39 |
+
|
| 40 |
+
class Segment():
    """A single cubic Bezier segment between two endpoints.

    The two inner control points are placed along the given tangent angles
    at a distance proportional to the endpoint separation.
    """

    def __init__(self, p1, p2, angle1, angle2, **kw):
        self.p1 = p1
        self.p2 = p2
        self.angle1 = angle1
        self.angle2 = angle2
        self.numpoints = kw.get("numpoints", 100)
        radius_frac = kw.get("r", 0.3)
        endpoint_dist = np.sqrt(np.sum((self.p2 - self.p1) ** 2))
        # Control-point distance scales with the segment length.
        self.r = radius_frac * endpoint_dist
        self.p = np.zeros((4, 2))
        self.p[0, :] = self.p1[:]
        self.p[3, :] = self.p2[:]
        self.calc_intermediate_points(self.r)

    def calc_intermediate_points(self, r):
        # NOTE(review): the *r* argument is ignored; self.r is used instead
        # (preserved from the original implementation).
        self.p[1, :] = self.p1 + np.array(
            [self.r * np.cos(self.angle1), self.r * np.sin(self.angle1)])
        self.p[2, :] = self.p2 + np.array(
            [self.r * np.cos(self.angle2 + np.pi), self.r * np.sin(self.angle2 + np.pi)])
        self.curve = bezier(self.p, self.numpoints)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_curve(points, **kw):
    """Chain Bezier segments through consecutive rows of *points*.

    Args:
        points: (n, 3) array; columns are x, y and the tangent angle.
        **kw: forwarded to Segment (e.g. r=, numpoints=).

    Returns:
        (segments, curve): the Segment objects and their concatenated samples.
    """
    segments = [
        Segment(points[i, :2], points[i + 1, :2], points[i, 2], points[i + 1, 2], **kw)
        for i in range(len(points) - 1)
    ]
    curve = np.concatenate([seg.curve for seg in segments])
    return segments, curve
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def ccw_sort(p):
    """Order the rows of *p* by angle around their centroid (counter-clockwise)."""
    centered = p - np.mean(p, axis=0)
    # Note the (x, y) argument order to arctan2 is preserved from the original.
    angles = np.arctan2(centered[:, 0], centered[:, 1])
    order = np.argsort(angles)
    return p[order, :]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_bezier_curve(a, rad=0.2, edgy=0):
    """Build a smooth closed Bezier curve through the points in *a*.

    Args:
        a: (n, 2) array of 2D points (sorted counter-clockwise internally).
        rad: in [0, 1]; scales the distance of the Bezier control points.
        edgy: controls how "edgy" the curve is; edgy=0 is smoothest.

    Returns:
        (x, y, a): sampled curve coordinates and the augmented point array
        (points, closed, with a third column of per-point tangent angles).
    """
    # Blend factor between a point's incoming and outgoing edge angles.
    p = np.arctan(edgy)/np.pi+.5
    a = ccw_sort(a)
    # Close the polygon by repeating the first point.
    a = np.append(a, np.atleast_2d(a[0,:]), axis=0)
    d = np.diff(a, axis=0)
    ang = np.arctan2(d[:,1],d[:,0])
    # Wrap edge angles into [0, 2*pi).
    f = lambda ang : (ang>=0)*ang + (ang<0)*(ang+2*np.pi)
    ang = f(ang)
    ang1 = ang
    ang2 = np.roll(ang,1)
    # Average neighboring edge angles; add pi where they straddle the cut.
    ang = p*ang1 + (1-p)*ang2 + (np.abs(ang2-ang1) > np.pi )*np.pi
    ang = np.append(ang, [ang[0]])
    # Third column: tangent angle at each point.
    a = np.append(a, np.atleast_2d(ang).T, axis=1)
    # NOTE(review): method="var" is swallowed by Segment's **kw and unused.
    s, c = get_curve(a, r=rad, method="var")
    x,y = c.T
    return x,y, a
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def get_random_points(n=5, scale=0.8, mindst=None, rec=0):
    """Draw *n* random points in the unit square that are at least *mindst*
    apart (per the original spacing metric), then scale them.

    Retries up to 200 times; after that the last draw is returned as-is.
    """
    # NOTE: `or` means an explicit mindst of 0 also falls back to .7/n,
    # matching the original behavior.
    mindst = mindst or .7 / n
    # Iterative form of the original tail recursion; *rec* keeps its meaning
    # as the starting attempt count.
    while True:
        candidate = np.random.rand(n, 2)
        spacing = np.sqrt(np.sum(np.diff(ccw_sort(candidate), axis=0), axis=1) ** 2)
        if np.all(spacing >= mindst) or rec >= 200:
            return candidate * scale
        rec += 1
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def fill_mask(shape, x, y, fill_val=255):
    """Rasterize the closed polygon with vertex coords (*x*, *y*) into a
    (h, w) uint8 mask, where (h, w) come from the (f, c, h, w) *shape*."""
    _, _, h, w = shape
    canvas = np.zeros((h, w), dtype=np.uint8)
    polygon = np.array([x, y], np.int32).T
    return cv2.fillPoly(canvas, [polygon], fill_val)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def random_shift(x, y, scale_range = [0.2, 0.7], trans_perturb_range=[-0.2, 0.2]):
    """Randomly scale and translate normalized contour coordinates, with an
    extra uniform jitter per axis. Inputs and outputs are in [0, 1]-ish space.

    The six np.random.uniform draws occur in the same order as the original,
    so seeded runs are reproducible.
    """
    scale_lo, scale_hi = scale_range
    w_scale = np.random.uniform(scale_lo, scale_hi)
    h_scale = np.random.uniform(scale_lo, scale_hi)
    x_offset = np.random.uniform(0., 1. - w_scale)
    y_offset = np.random.uniform(0., 1. - h_scale)
    jitter_lo, jitter_hi = trans_perturb_range
    shifted_x = x * w_scale + x_offset + np.random.uniform(jitter_lo, jitter_hi)
    shifted_y = y * h_scale + y_offset + np.random.uniform(jitter_lo, jitter_hi)
    return shifted_x, shifted_y
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def get_random_shape_mask(
    shape, n_pts_range=[3, 10], rad_range=[0.0, 1.0], edgy_range=[0.0, 0.1], n_keyframes_range=[2, 25],
    random_drop_range=[0.0, 0.2],
):
    """Animate a random Bezier blob across *f* frames and rasterize it.

    A random closed shape is drawn at each of several keyframes; in-between
    frames linearly interpolate the contour coordinates. A random contiguous
    span of frames is then zeroed out ("dropped").

    Args:
        shape: (f, c, h, w) of the target clip.
        n_pts_range / rad_range / edgy_range: sampling ranges for the blob.
        n_keyframes_range: how many keyframe shapes to draw.
        random_drop_range: fraction of frames to zero out.

    Returns:
        (f, h, w) float array with values in [0, 1].
    """
    f, _, h, w = shape

    n_pts = np.random.randint(n_pts_range[0], n_pts_range[1])
    n_keyframes = np.random.randint(n_keyframes_range[0], n_keyframes_range[1])
    # NOTE(review): if f < n_keyframes - 1 this interval is 0 and
    # range(0, f, 0) raises ValueError — presumably f is always >= ~24 here;
    # confirm against callers.
    keyframe_interval = f // (n_keyframes - 1)
    keyframe_indices = list(range(0, f, keyframe_interval))
    # Make sure the last keyframe lands exactly on the final frame.
    if len(keyframe_indices) == n_keyframes:
        keyframe_indices[-1] = f - 1
    else:
        keyframe_indices.append(f - 1)
    x_all_frames, y_all_frames = [], []
    for i, keyframe_index in enumerate(keyframe_indices):
        rad = np.random.uniform(rad_range[0], rad_range[1])
        edgy = np.random.uniform(edgy_range[0], edgy_range[1])
        x_kf, y_kf, _ = get_bezier_curve(get_random_points(n=n_pts), rad=rad, edgy=edgy)
        x_kf, y_kf = random_shift(x_kf, y_kf)
        if i == 0:
            x_all_frames.append(x_kf[None])
            y_all_frames.append(y_kf[None])
        else:
            # Linearly interpolate contours between the previous keyframe
            # (last stored frame) and this one; drop the duplicated endpoint.
            x_interval = np.linspace(x_all_frames[-1][-1], x_kf, keyframe_index - keyframe_indices[i - 1] + 1)
            y_interval = np.linspace(y_all_frames[-1][-1], y_kf, keyframe_index - keyframe_indices[i - 1] + 1)
            x_all_frames.append(x_interval[1:])
            y_all_frames.append(y_interval[1:])
    x_all_frames = np.concatenate(x_all_frames, axis=0)
    y_all_frames = np.concatenate(y_all_frames, axis=0)

    # Rasterize each per-frame contour into a binary mask.
    masks = []
    for x, y in zip(x_all_frames, y_all_frames):
        x = np.round(x * w).astype(np.int32)
        y = np.round(y * h).astype(np.int32)
        mask = fill_mask(shape, x, y)
        masks.append(mask)
    masks = np.stack(masks, axis=0).astype(float) / 255.

    # Zero out a random contiguous run of frames (object temporarily absent).
    n_frames_random_drop = int(np.random.uniform(random_drop_range[0], random_drop_range[1]) * f)
    drop_index = np.random.randint(0, f - n_frames_random_drop)
    masks[drop_index:drop_index + n_frames_random_drop] = 0

    return masks # (f, h, w), <float>[0, 1]
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def get_random_mask(shape, mask_type_probs=[0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.8]):
    """Generate a random inpainting mask for a clip of the given shape.

    Args:
        shape: (f, c, h, w) of the pixel tensor the mask is paired with.
        mask_type_probs: probability of each of the 11 mask families below;
            only consulted when f != 1 (single images pick between types 0/1).

    Returns:
        torch.FloatTensor of shape (f, 1, h, w) with values in [0, 1],
        where 1 marks the region to inpaint.
    """
    f, c, h, w = shape

    if f != 1:
        mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], p=mask_type_probs)
    else:
        # Single image: only a random rectangle or a full mask make sense.
        mask_index = np.random.choice([0, 1], p = [0.2, 0.8])
    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)

    if mask_index == 0:
        # Static rectangle shared by all frames.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()
        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()

        start_x = max(center_x - block_size_x // 2, 0)
        end_x = min(center_x + block_size_x // 2, w)
        start_y = max(center_y - block_size_y // 2, 0)
        end_y = min(center_y + block_size_y // 2, h)
        mask[:, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 1:
        # Mask everything (pure generation).
        mask[:, :, :, :] = 1
    elif mask_index == 2:
        # Keep the first 1-4 frames, mask the rest (continuation).
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:, :, :, :] = 1
    elif mask_index == 3:
        # Keep both ends, mask the middle (interpolation).
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
    elif mask_index == 4:
        # Rectangle active only on a random temporal span.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()
        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()

        start_x = max(center_x - block_size_x // 2, 0)
        end_x = min(center_x + block_size_x // 2, w)
        start_y = max(center_y - block_size_y // 2, 0)
        end_y = min(center_y + block_size_y // 2, h)

        mask_frame_before = np.random.randint(0, f // 2)
        mask_frame_after = np.random.randint(f // 2, f)
        mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 5:
        # Per-pixel Bernoulli noise mask.
        mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
    elif mask_index == 6:
        # Small random rectangles on a random subset of frames.
        num_frames_to_mask = random.randint(1, max(f // 2, 1))
        frames_to_mask = random.sample(range(f), num_frames_to_mask)

        for i in frames_to_mask:
            block_height = random.randint(1, h // 4)
            block_width = random.randint(1, w // 4)
            top_left_y = random.randint(0, h - block_height)
            top_left_x = random.randint(0, w - block_width)
            mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
    elif mask_index == 7:
        # Static ellipse shared by all frames.
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        # Guard: on very small frames min//8 could equal min//4 and
        # torch.randint(lo, lo) would raise; unchanged for min(h, w) >= 8.
        axis_low = max(min(w, h) // 8, 1)
        axis_high = max(min(w, h) // 4, axis_low + 1)
        a = torch.randint(axis_low, axis_high, (1,)).item()
        b = torch.randint(axis_low, axis_high, (1,)).item()

        # Vectorized replacement for the former O(h*w) Python loop. The
        # integer test dy^2*a^2 + dx^2*b^2 < a^2*b^2 is exactly equivalent
        # to dy^2/b^2 + dx^2/a^2 < 1 and avoids float rounding.
        dy2 = (torch.arange(h) - center_y) ** 2
        dx2 = (torch.arange(w) - center_x) ** 2
        inside = dy2[:, None] * (a * a) + dx2[None, :] * (b * b) < (a * a) * (b * b)
        mask[:, :, inside] = 1
    elif mask_index == 8:
        # Static circle shared by all frames (vectorized like the ellipse).
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        radius_low = max(min(h, w) // 8, 1)
        radius_high = max(min(h, w) // 4, radius_low + 1)
        radius = torch.randint(radius_low, radius_high, (1,)).item()
        dy2 = (torch.arange(h) - center_y) ** 2
        dx2 = (torch.arange(w) - center_x) ** 2
        inside = dy2[:, None] + dx2[None, :] < radius ** 2
        mask[:, :, inside] = 1
    elif mask_index == 9:
        # Each frame fully masked with probability 0.5.
        for idx in range(f):
            if np.random.rand() > 0.5:
                mask[idx, :, :, :] = 1
    else:
        # Union of 1-3 animated free-form (Bezier) shapes.
        num_objs = np.random.randint(1, 4)
        mask_npy = get_random_shape_mask(shape)
        for i in range(num_objs - 1):
            mask_npy += get_random_shape_mask(shape).clip(0, 1)
        # Fix: overlapping shapes previously summed above 1; clamp so this
        # branch returns [0, 1] values like every other branch (and like
        # get_random_mask_multi, which clips its unions).
        mask_npy = mask_npy.clip(0, 1)

        mask = torch.from_numpy(mask_npy).unsqueeze(1)

    return mask.float()
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def get_random_mask_multi(shape, mask_type_probs, range_num_masks=[1, 7]):
    """Union of several random masks, clamped to [0, 1].

    Draws np.random.randint(range_num_masks[0], range_num_masks[1]) masks
    via get_random_mask and ORs them together.
    """
    num_masks = np.random.randint(range_num_masks[0], range_num_masks[1])
    combined = None
    for _ in range(num_masks):
        new_mask = get_random_mask(shape, mask_type_probs)
        combined = new_mask if combined is None else (combined + new_mask).clip(0, 1)
    return combined
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class ImageVideoSampler(BatchSampler):
    """Batch sampler that groups indices by content type so every yielded
    batch is homogeneous ('image', 'video' or 'video_mask_tuple').

    Args:
        sampler (Sampler): Base sampler producing dataset indices.
        dataset (Dataset): Dataset whose ``dataset`` attribute is a list of
            per-sample info dicts carrying an optional 'type' key.
        batch_size (int): Size of mini-batch.
        drop_last (bool): Kept for API compatibility; partially filled
            buckets are always discarded when iteration ends.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # One pending-index bucket per content type.
        self.bucket = {'image': [], 'video': [], 'video_mask_tuple': []}

    def __iter__(self):
        for sample_idx in self.sampler:
            sample_type = self.dataset.dataset[sample_idx].get('type', 'image')
            self.bucket[sample_type].append(sample_idx)

            # Only the bucket that just grew can have reached batch_size,
            # so checking each in turn and breaking matches the original
            # elif chain exactly.
            for bucket_key in ('video', 'video_mask_tuple', 'image'):
                pending = self.bucket[bucket_key]
                if len(pending) == self.batch_size:
                    yield pending[:]
                    del pending[:]
                    break
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
@contextmanager
def VideoReader_contextmanager(*args, **kwargs):
    """Context manager around decord.VideoReader that guarantees cleanup.

    Yields the reader; on exit the reference is dropped and a GC pass is
    forced so long training loops don't accumulate reader resources.
    """
    vr = VideoReader(*args, **kwargs)
    try:
        yield vr
    finally:
        del vr
        gc.collect()
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def get_video_reader_batch(video_reader, batch_index):
    """Decode the frames at *batch_index* from *video_reader* and return them
    as a numpy array (frames, h, w, c)."""
    decoded = video_reader.get_batch(batch_index)
    return decoded.asnumpy()
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def _read_video_from_dir(video_dir):
|
| 344 |
+
frames = []
|
| 345 |
+
frame_paths = sorted(list(glob.glob(os.path.join(video_dir, '*.png'))))
|
| 346 |
+
|
| 347 |
+
if not frame_paths:
|
| 348 |
+
raise ValueError(f"No PNG files found in directory: {video_dir}")
|
| 349 |
+
|
| 350 |
+
for frame_path in frame_paths:
|
| 351 |
+
frame = media.read_image(frame_path)
|
| 352 |
+
frames.append(frame)
|
| 353 |
+
|
| 354 |
+
if not frames:
|
| 355 |
+
raise ValueError(f"Failed to read any frames from directory: {video_dir}")
|
| 356 |
+
|
| 357 |
+
return np.stack(frames, axis=0)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def resize_frame(frame, target_short_side):
    """Resize *frame* so its shorter side equals *target_short_side*, keeping
    aspect ratio. Frames whose short side is already smaller are returned
    unchanged (never upscaled)."""
    h, w, _ = frame.shape
    # Upscaling is never performed.
    if target_short_side > min(h, w):
        return frame
    if h < w:
        new_h = target_short_side
        new_w = int(target_short_side * w / h)
    else:
        new_w = target_short_side
        new_h = int(target_short_side * h / w)
    return cv2.resize(frame, (new_w, new_h))
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
class ImageVideoDataset(Dataset):
|
| 378 |
+
    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        enable_inpaint=False,
        trimask_zeroout_removal=False,
        use_quadmask=False,
        ablation_binary_mask=False,
    ):
        """Dataset over an annotation file (.csv or .json) mixing images,
        videos, and video/mask tuples.

        Args:
            ann_path: path to the annotation file; entries are dicts with
                'file_path', 'text' and an optional 'type'
                ('image' | 'video' | 'video_mask_tuple'; default 'image').
            data_root: optional root prepended to relative file paths.
            video_sample_size: target crop size for videos (int = square).
            video_sample_stride: temporal stride between sampled frames.
            video_sample_n_frames: number of frames sampled per clip.
            image_sample_size: target crop size for still images.
            video_repeat: extra copies of each 'video' entry to append.
                NOTE(review): with the default 0, 'video' entries are dropped
                entirely by the balancing loop below — confirm intended.
            text_drop_ratio: probability of replacing a caption with ''.
            enable_bucket: keep raw numpy frames (bucket training) instead of
                converting/transforming to tensors.
            video_length_drop_start: fraction of the clip head excluded.
            video_length_drop_end: fraction of the clip usable (tail cutoff).
            enable_inpaint: produce inpainting masks in __getitem__.
            trimask_zeroout_removal: removal-task mask option.
            use_quadmask: use the 4-value mask encoding [0, 63, 127, 255]
                instead of the 3-value [0, 127, 255].
            ablation_binary_mask: remap the quadmask to binary (see prints).
        """
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))
        else:
            raise ValueError(f"Unsupported annotation file format: {ann_path}. Only .csv and .json files are supported.")

        self.data_root = data_root

        # It's used to balance num of images and videos: non-video entries are
        # kept once, video entries are appended video_repeat times.
        self.dataset = []
        for data in dataset:
            if data.get('type', 'image') != 'video':
                self.dataset.append(data)
        if video_repeat > 0:
            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.trimask_zeroout_removal = trimask_zeroout_removal
        self.use_quadmask = use_quadmask
        self.ablation_binary_mask = ablation_binary_mask

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        if self.use_quadmask:
            print(f"[QUADMASK MODE] Using 4-value quadmask: [0, 63, 127, 255]")
            if self.ablation_binary_mask:
                print(f"[ABLATION BINARY MASK] Remapping quadmask to binary: [0,63]→0, [127,255]→127")
        else:
            print(f"[TRIMASK MODE] Using 3-value trimask: [0, 127, 255]")

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
        ])

        # Largest of the two minimum crop sides; used to bound decode-time
        # frame resizing so crops never upscale.
        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
|
| 459 |
+
|
| 460 |
+
def get_batch(self, idx):
|
| 461 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 462 |
+
|
| 463 |
+
if data_info.get('type', 'image') == 'video' and data_info.get('mask_path', None) is None:
|
| 464 |
+
video_id, text = data_info['file_path'], data_info['text']
|
| 465 |
+
|
| 466 |
+
if self.data_root is None:
|
| 467 |
+
video_dir = video_id
|
| 468 |
+
else:
|
| 469 |
+
video_dir = os.path.join(self.data_root, video_id)
|
| 470 |
+
|
| 471 |
+
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
|
| 472 |
+
min_sample_n_frames = min(
|
| 473 |
+
self.video_sample_n_frames,
|
| 474 |
+
int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
|
| 475 |
+
)
|
| 476 |
+
if min_sample_n_frames == 0:
|
| 477 |
+
raise ValueError(f"No Frames in video.")
|
| 478 |
+
|
| 479 |
+
video_length = int(self.video_length_drop_end * len(video_reader))
|
| 480 |
+
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
|
| 481 |
+
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
|
| 482 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
|
| 483 |
+
|
| 484 |
+
try:
|
| 485 |
+
sample_args = (video_reader, batch_index)
|
| 486 |
+
pixel_values = func_timeout(
|
| 487 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 488 |
+
)
|
| 489 |
+
resized_frames = []
|
| 490 |
+
for i in range(len(pixel_values)):
|
| 491 |
+
frame = pixel_values[i]
|
| 492 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 493 |
+
resized_frames.append(resized_frame)
|
| 494 |
+
pixel_values = np.array(resized_frames)
|
| 495 |
+
except FunctionTimedOut:
|
| 496 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 497 |
+
except Exception as e:
|
| 498 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 499 |
+
|
| 500 |
+
if not self.enable_bucket:
|
| 501 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 502 |
+
pixel_values = pixel_values / 255.
|
| 503 |
+
del video_reader
|
| 504 |
+
else:
|
| 505 |
+
pixel_values = pixel_values
|
| 506 |
+
|
| 507 |
+
if not self.enable_bucket:
|
| 508 |
+
pixel_values = self.video_transforms(pixel_values)
|
| 509 |
+
|
| 510 |
+
# Random use no text generation
|
| 511 |
+
if random.random() < self.text_drop_ratio:
|
| 512 |
+
text = ''
|
| 513 |
+
return {
|
| 514 |
+
'pixel_values': pixel_values,
|
| 515 |
+
'text': text,
|
| 516 |
+
'data_type': 'video',
|
| 517 |
+
}
|
| 518 |
+
elif data_info.get('type', 'image') == 'video' and data_info.get('mask_path', None) is not None: # video with known mask
|
| 519 |
+
video_path, text = data_info['file_path'], data_info['text']
|
| 520 |
+
mask_video_path = video_path[:-4] + '_mask.mp4'
|
| 521 |
+
with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
|
| 522 |
+
min_sample_n_frames = min(
|
| 523 |
+
self.video_sample_n_frames,
|
| 524 |
+
int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
|
| 525 |
+
)
|
| 526 |
+
if min_sample_n_frames == 0:
|
| 527 |
+
raise ValueError(f"No Frames in video.")
|
| 528 |
+
|
| 529 |
+
video_length = int(self.video_length_drop_end * len(video_reader))
|
| 530 |
+
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
|
| 531 |
+
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
|
| 532 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
|
| 533 |
+
|
| 534 |
+
try:
|
| 535 |
+
sample_args = (video_reader, batch_index)
|
| 536 |
+
pixel_values = func_timeout(
|
| 537 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 538 |
+
)
|
| 539 |
+
resized_frames = []
|
| 540 |
+
for i in range(len(pixel_values)):
|
| 541 |
+
frame = pixel_values[i]
|
| 542 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 543 |
+
resized_frames.append(resized_frame)
|
| 544 |
+
input_video = np.array(resized_frames)
|
| 545 |
+
except FunctionTimedOut:
|
| 546 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 547 |
+
except Exception as e:
|
| 548 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 549 |
+
|
| 550 |
+
with VideoReader_contextmanager(mask_video_path, num_threads=2) as video_reader:
|
| 551 |
+
try:
|
| 552 |
+
sample_args = (video_reader, batch_index)
|
| 553 |
+
mask_values = func_timeout(
|
| 554 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 555 |
+
)
|
| 556 |
+
resized_frames = []
|
| 557 |
+
for i in range(len(mask_values)):
|
| 558 |
+
frame = mask_values[i]
|
| 559 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 560 |
+
resized_frames.append(resized_frame)
|
| 561 |
+
mask_video = np.array(resized_frames)
|
| 562 |
+
except FunctionTimedOut:
|
| 563 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 564 |
+
except Exception as e:
|
| 565 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 566 |
+
|
| 567 |
+
if len(mask_video.shape) == 3:
|
| 568 |
+
mask_video = mask_video[..., None]
|
| 569 |
+
if mask_video.shape[-1] == 3:
|
| 570 |
+
mask_video = mask_video[..., :1]
|
| 571 |
+
if len(mask_video.shape) != 4:
|
| 572 |
+
raise ValueError(f"mask_video shape is {mask_video.shape}.")
|
| 573 |
+
|
| 574 |
+
text = data_info['text']
|
| 575 |
+
if not self.enable_bucket:
|
| 576 |
+
input_video = torch.from_numpy(input_video).permute(0, 3, 1, 2).contiguous() / 255.
|
| 577 |
+
mask_video = torch.from_numpy(mask_video).permute(0, 3, 1, 2).contiguous() / 255.
|
| 578 |
+
|
| 579 |
+
pixel_values = torch.cat([input_video, mask_video], dim=1)
|
| 580 |
+
pixel_values = self.video_transforms(pixel_values)
|
| 581 |
+
input_video = pixel_values[:, :3]
|
| 582 |
+
mask_video = pixel_values[:, 3:]
|
| 583 |
+
|
| 584 |
+
# Random use no text generation
|
| 585 |
+
if random.random() < self.text_drop_ratio:
|
| 586 |
+
text = ''
|
| 587 |
+
|
| 588 |
+
return {
|
| 589 |
+
'pixel_values': input_video,
|
| 590 |
+
'mask': mask_video,
|
| 591 |
+
'text': text,
|
| 592 |
+
'data_type': 'video',
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
elif data_info.get('type', 'image') == 'video_mask_tuple': # object effect removal
|
| 596 |
+
sample_dir = data_info['file_path'] if self.data_root is None else os.path.join(self.data_root, data_info['file_path'])
|
| 597 |
+
try:
|
| 598 |
+
if os.path.exists(os.path.join(sample_dir, 'rgb_full.mp4')):
|
| 599 |
+
input_video_path = os.path.join(sample_dir, 'rgb_full.mp4')
|
| 600 |
+
target_video_path = os.path.join(sample_dir, 'rgb_removed.mp4')
|
| 601 |
+
mask_video_path = os.path.join(sample_dir, 'mask.mp4')
|
| 602 |
+
depth_video_path = os.path.join(sample_dir, 'depth_removed.mp4')
|
| 603 |
+
|
| 604 |
+
input_video = media.read_video(input_video_path)
|
| 605 |
+
target_video = media.read_video(target_video_path)
|
| 606 |
+
mask_video = media.read_video(mask_video_path)
|
| 607 |
+
|
| 608 |
+
# Load depth map if it exists
|
| 609 |
+
depth_video = None
|
| 610 |
+
if os.path.exists(depth_video_path):
|
| 611 |
+
depth_video = media.read_video(depth_video_path)
|
| 612 |
+
|
| 613 |
+
else:
|
| 614 |
+
input_video_path = os.path.join(sample_dir, 'input')
|
| 615 |
+
target_video_path = os.path.join(sample_dir, 'bg')
|
| 616 |
+
mask_video_path = os.path.join(sample_dir, 'trimask')
|
| 617 |
+
|
| 618 |
+
input_video = _read_video_from_dir(input_video_path)
|
| 619 |
+
target_video = _read_video_from_dir(target_video_path)
|
| 620 |
+
mask_video = _read_video_from_dir(mask_video_path)
|
| 621 |
+
|
| 622 |
+
# Initialize depth_video as None for this path
|
| 623 |
+
depth_video = None
|
| 624 |
+
except Exception as e:
|
| 625 |
+
print(f"Error loading video_mask_tuple from {sample_dir}: {e}")
|
| 626 |
+
import traceback
|
| 627 |
+
traceback.print_exc()
|
| 628 |
+
raise
|
| 629 |
+
|
| 630 |
+
mask_video = 255 - mask_video # will be flipped again in when feeding to model
|
| 631 |
+
|
| 632 |
+
if len(mask_video.shape) == 3:
|
| 633 |
+
mask_video = mask_video[..., None]
|
| 634 |
+
if mask_video.shape[-1] == 3:
|
| 635 |
+
mask_video = mask_video[..., :1]
|
| 636 |
+
min_sample_n_frames = min(
|
| 637 |
+
self.video_sample_n_frames,
|
| 638 |
+
int(len(input_video) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
|
| 639 |
+
)
|
| 640 |
+
video_length = int(self.video_length_drop_end * len(input_video))
|
| 641 |
+
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
|
| 642 |
+
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
|
| 643 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
|
| 644 |
+
input_video = input_video[batch_index]
|
| 645 |
+
target_video = target_video[batch_index]
|
| 646 |
+
mask_video = mask_video[batch_index]
|
| 647 |
+
if depth_video is not None:
|
| 648 |
+
depth_video = depth_video[batch_index]
|
| 649 |
+
|
| 650 |
+
resized_inputs = []
|
| 651 |
+
resized_targets = []
|
| 652 |
+
resized_masks = []
|
| 653 |
+
resized_depths = []
|
| 654 |
+
for i in range(len(input_video)):
|
| 655 |
+
resized_input = resize_frame(input_video[i], self.larger_side_of_image_and_video)
|
| 656 |
+
resized_target = resize_frame(target_video[i], self.larger_side_of_image_and_video)
|
| 657 |
+
resized_mask = resize_frame(mask_video[i], self.larger_side_of_image_and_video)
|
| 658 |
+
|
| 659 |
+
# Apply mask quantization based on mode
|
| 660 |
+
if self.ablation_binary_mask:
|
| 661 |
+
# Ablation binary mask mode: remap [0, 63, 127, 255] to [0, 127]
|
| 662 |
+
# Map 0 and 63 → 0
|
| 663 |
+
# Map 127 and 255 → 127
|
| 664 |
+
resized_mask = np.where(resized_mask <= 95, 0, resized_mask)
|
| 665 |
+
resized_mask = np.where(resized_mask > 95, 127, resized_mask)
|
| 666 |
+
elif self.use_quadmask:
|
| 667 |
+
# Quadmask mode: preserve 4 values [0, 63, 127, 255]
|
| 668 |
+
# Quantize to nearest quadmask value for robustness
|
| 669 |
+
resized_mask = np.where(resized_mask <= 31, 0, resized_mask)
|
| 670 |
+
resized_mask = np.where(np.logical_and(resized_mask > 31, resized_mask <= 95), 63, resized_mask)
|
| 671 |
+
resized_mask = np.where(np.logical_and(resized_mask > 95, resized_mask <= 191), 127, resized_mask)
|
| 672 |
+
resized_mask = np.where(resized_mask > 191, 255, resized_mask)
|
| 673 |
+
else:
|
| 674 |
+
# Trimask mode: 3 values [0, 127, 255]
|
| 675 |
+
resized_mask = np.where(np.logical_and(resized_mask > 63, resized_mask < 192), 127, resized_mask)
|
| 676 |
+
resized_mask = np.where(resized_mask >= 192, 255, resized_mask)
|
| 677 |
+
resized_mask = np.where(resized_mask <= 63, 0, resized_mask)
|
| 678 |
+
|
| 679 |
+
resized_inputs.append(resized_input)
|
| 680 |
+
resized_targets.append(resized_target)
|
| 681 |
+
resized_masks.append(resized_mask)
|
| 682 |
+
|
| 683 |
+
if depth_video is not None:
|
| 684 |
+
resized_depth = resize_frame(depth_video[i], self.larger_side_of_image_and_video)
|
| 685 |
+
resized_depths.append(resized_depth)
|
| 686 |
+
|
| 687 |
+
input_video = np.array(resized_inputs)
|
| 688 |
+
target_video = np.array(resized_targets)
|
| 689 |
+
mask_video = np.array(resized_masks)
|
| 690 |
+
if depth_video is not None:
|
| 691 |
+
depth_video = np.array(resized_depths)
|
| 692 |
+
|
| 693 |
+
if len(mask_video.shape) == 3:
|
| 694 |
+
mask_video = mask_video[..., None]
|
| 695 |
+
if mask_video.shape[-1] == 3:
|
| 696 |
+
mask_video = mask_video[..., :1]
|
| 697 |
+
if len(mask_video.shape) != 4:
|
| 698 |
+
raise ValueError(f"mask_video shape is {mask_video.shape}.")
|
| 699 |
+
|
| 700 |
+
text = data_info['text']
|
| 701 |
+
print(f"DEBUG DATASET: Converting to tensors (enable_bucket={self.enable_bucket})...")
|
| 702 |
+
if not self.enable_bucket:
|
| 703 |
+
print(f"DEBUG DATASET: Converting input_video to tensor...")
|
| 704 |
+
input_video = torch.from_numpy(input_video).permute(0, 3, 1, 2).contiguous() / 255.
|
| 705 |
+
print(f"DEBUG DATASET: Converting target_video to tensor...")
|
| 706 |
+
target_video = torch.from_numpy(target_video).permute(0, 3, 1, 2).contiguous() / 255.
|
| 707 |
+
print(f"DEBUG DATASET: Converting mask_video to tensor...")
|
| 708 |
+
mask_video = torch.from_numpy(mask_video).permute(0, 3, 1, 2).contiguous() / 255.
|
| 709 |
+
|
| 710 |
+
# Process depth video if available
|
| 711 |
+
if depth_video is not None:
|
| 712 |
+
print(f"DEBUG DATASET: Processing depth_video...")
|
| 713 |
+
# IMPORTANT: Copy depth_video to ensure it's not memory-mapped
|
| 714 |
+
# Memory-mapped files can cause bus errors on GPU transfer
|
| 715 |
+
print(f"DEBUG DATASET: Copying depth_video to ensure not memory-mapped...")
|
| 716 |
+
depth_video = np.array(depth_video, copy=True)
|
| 717 |
+
print(f"DEBUG DATASET: depth_video copied, shape={depth_video.shape}")
|
| 718 |
+
|
| 719 |
+
# Ensure depth has correct shape
|
| 720 |
+
if len(depth_video.shape) == 3:
|
| 721 |
+
depth_video = depth_video[..., None]
|
| 722 |
+
if depth_video.shape[-1] == 3:
|
| 723 |
+
# Convert to grayscale if RGB
|
| 724 |
+
print(f"DEBUG DATASET: Converting depth to grayscale...")
|
| 725 |
+
depth_video = depth_video.mean(axis=-1, keepdims=True)
|
| 726 |
+
# Convert to tensor [F, 1, H, W] and normalize to [0, 1]
|
| 727 |
+
print(f"DEBUG DATASET: Converting depth to tensor...")
|
| 728 |
+
depth_video = torch.from_numpy(depth_video).permute(0, 3, 1, 2).contiguous().float() / 255.
|
| 729 |
+
# Ensure tensor is contiguous and owned
|
| 730 |
+
print(f"DEBUG DATASET: Cloning depth tensor...")
|
| 731 |
+
depth_video = depth_video.clone().contiguous()
|
| 732 |
+
print(f"DEBUG DATASET: depth_video final shape: {depth_video.shape}, is_contiguous: {depth_video.is_contiguous()}")
|
| 733 |
+
|
| 734 |
+
# Apply transforms to each video separately (they expect 3 channels)
|
| 735 |
+
print(f"DEBUG DATASET: Applying video transforms...")
|
| 736 |
+
input_video = self.video_transforms(input_video)
|
| 737 |
+
target_video = self.video_transforms(target_video)
|
| 738 |
+
# Don't normalize mask since it's single channel
|
| 739 |
+
print(f"DEBUG DATASET: Normalizing mask_video...")
|
| 740 |
+
mask_video = mask_video * 2.0 - 1.0 # Scale to [-1, 1] like other channels
|
| 741 |
+
print(f"DEBUG DATASET: All tensors ready (non-bucket mode)")
|
| 742 |
+
|
| 743 |
+
else:
|
| 744 |
+
# For bucket mode, keep as numpy until collate
|
| 745 |
+
# Collate function expects [0, 255] range and will normalize
|
| 746 |
+
print(f"DEBUG DATASET: Bucket mode - keeping as numpy in [0, 255] range...")
|
| 747 |
+
print(f"DEBUG DATASET: All numpy arrays ready (bucket mode)")
|
| 748 |
+
|
| 749 |
+
# Load warped noise - REQUIRED if specified in dataset
|
| 750 |
+
warped_noise = None
|
| 751 |
+
if 'warped_noise_path' in data_info:
|
| 752 |
+
warped_noise_dir = data_info['warped_noise_path'] if self.data_root is None else os.path.join(self.data_root, data_info['warped_noise_path'])
|
| 753 |
+
noise_path = os.path.join(warped_noise_dir, 'noises.npy')
|
| 754 |
+
|
| 755 |
+
if not os.path.exists(noise_path):
|
| 756 |
+
raise FileNotFoundError(
|
| 757 |
+
f"Warped noise path specified in dataset but file not found: {noise_path}\n"
|
| 758 |
+
f"Make sure you've generated warped noise for all videos in the dataset."
|
| 759 |
+
)
|
| 760 |
+
|
| 761 |
+
try:
|
| 762 |
+
warped_noise = np.load(noise_path) # Shape: (T, C, H, W) in float16
|
| 763 |
+
warped_noise = torch.from_numpy(warped_noise).float() # Convert to torch tensor
|
| 764 |
+
except Exception as e:
|
| 765 |
+
raise RuntimeError(
|
| 766 |
+
f"Failed to load warped noise from {noise_path}: {e}\n"
|
| 767 |
+
f"The noise file may be corrupted. Try regenerating it."
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
+
# Random use no text generation
|
| 771 |
+
if random.random() < self.text_drop_ratio:
|
| 772 |
+
text = ''
|
| 773 |
+
|
| 774 |
+
if self.trimask_zeroout_removal:
|
| 775 |
+
input_video = input_video * np.where(mask_video > 200, 0, 1).astype(input_video.dtype)
|
| 776 |
+
|
| 777 |
+
result = {
|
| 778 |
+
'pixel_values': target_video,
|
| 779 |
+
'input_condition': input_video,
|
| 780 |
+
'mask': mask_video,
|
| 781 |
+
'text': text,
|
| 782 |
+
'data_type': 'video_mask_tuple',
|
| 783 |
+
}
|
| 784 |
+
|
| 785 |
+
# Add depth maps if available
|
| 786 |
+
if depth_video is not None:
|
| 787 |
+
result['depth_maps'] = depth_video
|
| 788 |
+
|
| 789 |
+
# Add warped noise to batch if available
|
| 790 |
+
if warped_noise is not None:
|
| 791 |
+
result['warped_noise'] = warped_noise
|
| 792 |
+
|
| 793 |
+
return result
|
| 794 |
+
|
| 795 |
+
else:
|
| 796 |
+
image_path, text = data_info['file_path'], data_info['text']
|
| 797 |
+
if self.data_root is not None:
|
| 798 |
+
image_path = os.path.join(self.data_root, image_path)
|
| 799 |
+
image = Image.open(image_path).convert('RGB')
|
| 800 |
+
if not self.enable_bucket:
|
| 801 |
+
image = self.image_transforms(image).unsqueeze(0)
|
| 802 |
+
else:
|
| 803 |
+
image = np.expand_dims(np.array(image), 0)
|
| 804 |
+
if random.random() < self.text_drop_ratio:
|
| 805 |
+
text = ''
|
| 806 |
+
return {
|
| 807 |
+
'pixel_values': image,
|
| 808 |
+
'text': text,
|
| 809 |
+
'data_type': 'image',
|
| 810 |
+
}
|
| 811 |
+
|
| 812 |
+
def __len__(self):
|
| 813 |
+
return self.length
|
| 814 |
+
|
| 815 |
+
    def __getitem__(self, idx):
        """Fetch one training sample, retrying with a random index on failure.

        Delegates the heavy lifting to ``self.get_batch``; on any exception the
        error is logged and a new random index is drawn, so a corrupt file never
        kills the training loop.  When inpainting is enabled (and bucketing is
        not), the sample is augmented with mask / clip / reference tensors.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                # Re-read the entry: after a retry `idx` may point at a
                # different record, and we refuse to silently switch data types.
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                sample = self.get_batch(idx)
                sample["idx"] = idx

                if len(sample) > 0:
                    break
            except Exception as e:
                import traceback
                print(f"Error loading sample at index {idx}:")
                print(f"Data info: {self.dataset[idx % len(self.dataset)]}")
                print(f"Error: {e}")
                traceback.print_exc()
                # Fall back to a random sample rather than propagating.
                idx = random.randint(0, self.length-1)

        if self.enable_inpaint and not self.enable_bucket:
            # Use the mask produced by get_batch when present (e.g. the
            # video_mask_tuple path), otherwise synthesize a random one.
            if "mask" not in sample:
                mask = get_random_mask_multi(sample["pixel_values"].size())
                sample["mask"] = mask
            else:
                mask = sample["mask"]

            if "input_condition" in sample:
                # Object-removal samples already carry the conditioning video.
                mask_pixel_values = sample["input_condition"]
            else:
                mask_pixel_values = sample["pixel_values"]
                # Masked regions are filled with -1 (the normalized "black").
                # NOTE(review): masking is applied only on this branch —
                # confirm that input_condition is expected pre-masked.
                mask_pixel_values = mask_pixel_values * (1 - mask) + torch.ones_like(mask_pixel_values) * -1 * mask

            sample["mask_pixel_values"] = mask_pixel_values

            # First frame as HWC in [0, 255] for a CLIP-style image encoder.
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

            # Reference frame; blanked to -1 when the mask covers everything.
            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
            if (mask == 1).all():
                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
            sample["ref_pixel_values"] = ref_pixel_values

        return sample
|
| 864 |
+
|
| 865 |
+
|
| 866 |
+
class ImageVideoControlDataset(Dataset):
    """Dataset of (sample, control sample) pairs for controllable generation.

    Each annotation entry provides ``file_path`` plus ``control_file_path``;
    videos are decoded with decord and images with PIL.  When
    ``enable_bucket`` is False, frames are normalized tensors in [-1, 1];
    otherwise raw uint8 numpy arrays are returned for bucket-based collation.
    """

    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        enable_inpaint=False,
    ):
        # Loading annotations from files (.csv or .json only).
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))
        else:
            raise ValueError(f"Unsupported annotation file format: {ann_path}. Only .csv and .json files are supported.")

        self.data_root = data_root

        # It's used to balance num of images and videos: non-video entries are
        # kept once; video entries are appended `video_repeat` times.
        # NOTE(review): with video_repeat == 0 videos are excluded entirely —
        # confirm that is intended.
        self.dataset = []
        for data in dataset:
            if data.get('type', 'image') != 'video':
                self.dataset.append(data)
        if video_repeat > 0:
            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint

        # Fractions of the clip dropped from the start / kept up to the end.
        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
        ])

        # Target for the pre-crop resize of the larger frame side.
        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        """Load one (pixel, control, text, type) tuple for index ``idx``.

        Raises ValueError on decode failure or timeout so the caller's retry
        loop can pick another index.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        video_id, text = data_info['file_path'], data_info['text']

        if data_info.get('type', 'image')=='video':
            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                # Clamp the requested frame count to what the usable portion
                # of the video can supply at the configured stride.
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError(f"No Frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    # Guarded by a timeout: a stuck decoder raises instead of hanging.
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    # To float tensor FCHW in [0, 1].
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader
                else:
                    pixel_values = pixel_values

                if not self.enable_bucket:
                    pixel_values = self.video_transforms(pixel_values)

                # Random use no text generation (classifier-free guidance drop).
                if random.random() < self.text_drop_ratio:
                    text = ''

            control_video_id = data_info['control_file_path']

            if self.data_root is None:
                control_video_id = control_video_id
            else:
                control_video_id = os.path.join(self.data_root, control_video_id)

            # Decode the control video at the SAME frame indices as the sample.
            with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
                try:
                    sample_args = (control_video_reader, batch_index)
                    control_pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(control_pixel_values)):
                        frame = control_pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    control_pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
                    control_pixel_values = control_pixel_values / 255.
                    del control_video_reader
                else:
                    control_pixel_values = control_pixel_values

                if not self.enable_bucket:
                    control_pixel_values = self.video_transforms(control_pixel_values)
                return pixel_values, control_pixel_values, text, "video"
        else:
            # Image branch: single frame with batch dim, plus its control image.
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)

            if random.random() < self.text_drop_ratio:
                text = ''

            control_image_id = data_info['control_file_path']

            if self.data_root is None:
                control_image_id = control_image_id
            else:
                control_image_id = os.path.join(self.data_root, control_image_id)

            control_image = Image.open(control_image_id).convert('RGB')
            if not self.enable_bucket:
                control_image = self.image_transforms(control_image).unsqueeze(0)
            else:
                control_image = np.expand_dims(np.array(control_image), 0)
            return image, control_image, text, 'image'

    def __len__(self):
        """Return the number of annotation entries."""
        return self.length

    def __getitem__(self, idx):
        """Fetch one sample, retrying with a random index on any failure."""
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                # Refuse to switch data types across retries.
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["control_pixel_values"] = control_pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx

                if len(sample) > 0:
                    break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)

        if self.enable_inpaint and not self.enable_bucket:
            # Random inpainting mask; masked pixels are set to -1.
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            # First frame as HWC in [0, 255] for a CLIP-style image encoder.
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

            # Reference frame; blanked when the mask covers everything.
            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
            if (mask == 1).all():
                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
            sample["ref_pixel_values"] = ref_pixel_values

        return sample
|
videox_fun/data/dataset_video.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import gc
|
| 3 |
+
import io
|
| 4 |
+
import json
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from threading import Thread
|
| 10 |
+
|
| 11 |
+
import albumentations
|
| 12 |
+
import cv2
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torchvision.transforms as transforms
|
| 16 |
+
from decord import VideoReader
|
| 17 |
+
from einops import rearrange
|
| 18 |
+
from func_timeout import FunctionTimedOut, func_timeout
|
| 19 |
+
from PIL import Image
|
| 20 |
+
from torch.utils.data import BatchSampler, Sampler
|
| 21 |
+
from torch.utils.data.dataset import Dataset
|
| 22 |
+
|
| 23 |
+
VIDEO_READER_TIMEOUT = 20
|
| 24 |
+
|
| 25 |
+
def get_random_mask(shape):
|
| 26 |
+
f, c, h, w = shape
|
| 27 |
+
|
| 28 |
+
mask_index = np.random.randint(0, 4)
|
| 29 |
+
mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
|
| 30 |
+
if mask_index == 0:
|
| 31 |
+
mask[1:, :, :, :] = 1
|
| 32 |
+
elif mask_index == 1:
|
| 33 |
+
mask_frame_index = 1
|
| 34 |
+
mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
|
| 35 |
+
elif mask_index == 2:
|
| 36 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 37 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 38 |
+
block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
|
| 39 |
+
block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
|
| 40 |
+
|
| 41 |
+
start_x = max(center_x - block_size_x // 2, 0)
|
| 42 |
+
end_x = min(center_x + block_size_x // 2, w)
|
| 43 |
+
start_y = max(center_y - block_size_y // 2, 0)
|
| 44 |
+
end_y = min(center_y + block_size_y // 2, h)
|
| 45 |
+
mask[:, :, start_y:end_y, start_x:end_x] = 1
|
| 46 |
+
elif mask_index == 3:
|
| 47 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 48 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 49 |
+
block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
|
| 50 |
+
block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
|
| 51 |
+
|
| 52 |
+
start_x = max(center_x - block_size_x // 2, 0)
|
| 53 |
+
end_x = min(center_x + block_size_x // 2, w)
|
| 54 |
+
start_y = max(center_y - block_size_y // 2, 0)
|
| 55 |
+
end_y = min(center_y + block_size_y // 2, h)
|
| 56 |
+
|
| 57 |
+
mask_frame_before = np.random.randint(0, f // 2)
|
| 58 |
+
mask_frame_after = np.random.randint(f // 2, f)
|
| 59 |
+
mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
|
| 60 |
+
else:
|
| 61 |
+
raise ValueError(f"The mask_index {mask_index} is not define")
|
| 62 |
+
return mask
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@contextmanager
def VideoReader_contextmanager(*args, **kwargs):
    """Yield a decord ``VideoReader`` and always release it afterwards.

    The reader is deleted and the garbage collector run on exit so decoder
    resources are freed promptly even if the body raises.
    """
    reader = VideoReader(*args, **kwargs)
    try:
        yield reader
    finally:
        del reader
        gc.collect()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_video_reader_batch(video_reader, batch_index):
    """Decode the frames at ``batch_index`` and return them as a numpy array."""
    return video_reader.get_batch(batch_index).asnumpy()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class WebVid10M(Dataset):
    """WebVid-10M dataset: clips (or single frames) decoded with decord.

    Annotations come from a CSV with ``videoid``, ``name`` and ``page_dir``
    columns; videos are expected at ``<video_folder>/<videoid>.mp4``.
    """

    def __init__(
        self,
        csv_path, video_folder,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        enable_bucket=False, enable_inpaint=False, is_image=False,
    ):
        print(f"loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride      # temporal stride between sampled frames
        self.sample_n_frames = sample_n_frames  # frames per clip
        self.enable_bucket = enable_bucket      # skip normalization, return raw numpy
        self.enable_inpaint = enable_inpaint    # attach random mask tensors
        self.is_image = is_image                # sample one random frame instead of a clip

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        """Decode one clip (or frame) for index ``idx``; returns (pixels, caption)."""
        video_dict = self.dataset[idx]
        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']

        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = VideoReader(video_dir)
        video_length = len(video_reader)

        if not self.is_image:
            # Evenly spaced indices over a random window of the video.
            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            batch_index = [random.randint(0, video_length - 1)]

        if not self.enable_bucket:
            # To float tensor FCHW in [0, 1].
            pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
            pixel_values = pixel_values / 255.
            del video_reader
        else:
            pixel_values = video_reader.get_batch(batch_index).asnumpy()

        if self.is_image:
            # Drop the frame dimension for single-image samples.
            pixel_values = pixel_values[0]
        return pixel_values, name

    def __len__(self):
        """Return the number of annotation entries."""
        return self.length

    def __getitem__(self, idx):
        """Fetch one sample, retrying with a random index on decode failure."""
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break

            except Exception as e:
                print("Error info:", e)
                idx = random.randint(0, self.length-1)

        if not self.enable_bucket:
            pixel_values = self.pixel_transforms(pixel_values)
        if self.enable_inpaint:
            # Random mask; masked pixels become -1 in the normalized video.
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
        else:
            sample = dict(pixel_values=pixel_values, text=name)
        return sample
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class VideoDataset(Dataset):
|
| 158 |
+
    def __init__(
        self,
        json_path, video_folder=None,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        enable_bucket=False, enable_inpaint=False
    ):
        """Video dataset driven by a JSON annotation list.

        Each entry provides ``file_path`` (relative to ``video_folder`` when
        given) and a ``text`` caption.
        """
        print(f"loading annotations from {json_path} ...")
        self.dataset = json.load(open(json_path, 'r'))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride      # temporal stride between sampled frames
        self.sample_n_frames = sample_n_frames  # frames per clip
        self.enable_bucket = enable_bucket      # skip normalization, return raw numpy
        self.enable_inpaint = enable_inpaint    # attach random mask tensors

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose(
            [
                transforms.Resize(sample_size[0]),
                transforms.CenterCrop(sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
|
| 183 |
+
|
| 184 |
+
def get_batch(self, idx):
|
| 185 |
+
video_dict = self.dataset[idx]
|
| 186 |
+
video_id, name = video_dict['file_path'], video_dict['text']
|
| 187 |
+
|
| 188 |
+
if self.video_folder is None:
|
| 189 |
+
video_dir = video_id
|
| 190 |
+
else:
|
| 191 |
+
video_dir = os.path.join(self.video_folder, video_id)
|
| 192 |
+
|
| 193 |
+
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
|
| 194 |
+
video_length = len(video_reader)
|
| 195 |
+
|
| 196 |
+
clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
|
| 197 |
+
start_idx = random.randint(0, video_length - clip_length)
|
| 198 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
|
| 199 |
+
|
| 200 |
+
try:
|
| 201 |
+
sample_args = (video_reader, batch_index)
|
| 202 |
+
pixel_values = func_timeout(
|
| 203 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 204 |
+
)
|
| 205 |
+
except FunctionTimedOut:
|
| 206 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 207 |
+
except Exception as e:
|
| 208 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 209 |
+
|
| 210 |
+
if not self.enable_bucket:
|
| 211 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 212 |
+
pixel_values = pixel_values / 255.
|
| 213 |
+
del video_reader
|
| 214 |
+
else:
|
| 215 |
+
pixel_values = pixel_values
|
| 216 |
+
|
| 217 |
+
return pixel_values, name
|
| 218 |
+
|
| 219 |
+
def __len__(self):
|
| 220 |
+
return self.length
|
| 221 |
+
|
| 222 |
+
def __getitem__(self, idx):
|
| 223 |
+
while True:
|
| 224 |
+
try:
|
| 225 |
+
pixel_values, name = self.get_batch(idx)
|
| 226 |
+
break
|
| 227 |
+
|
| 228 |
+
except Exception as e:
|
| 229 |
+
print("Error info:", e)
|
| 230 |
+
idx = random.randint(0, self.length-1)
|
| 231 |
+
|
| 232 |
+
if not self.enable_bucket:
|
| 233 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
| 234 |
+
if self.enable_inpaint:
|
| 235 |
+
mask = get_random_mask(pixel_values.size())
|
| 236 |
+
mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
|
| 237 |
+
sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
|
| 238 |
+
else:
|
| 239 |
+
sample = dict(pixel_values=pixel_values, text=name)
|
| 240 |
+
return sample
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
if __name__ == "__main__":
    # Smoke test: build one of the dataset variants and print batch shapes.
    use_video_dataset = True
    use_webvid = False

    if use_video_dataset:
        dataset = VideoDataset(
            json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
            sample_size=256,
            sample_stride=4, sample_n_frames=16,
        )

    if use_webvid:
        dataset = WebVid10M(
            csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
            video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
            sample_size=256,
            sample_stride=4, sample_n_frames=16,
            is_image=False,
        )

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0)
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))
|
videox_fun/dist/__init__.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
import torch.distributed as dist

# xfuser (xDiT) supplies the sequence-parallel (Ulysses/ring) primitives used
# for multi-GPU inference. It is an optional dependency: when the import fails,
# every helper is stubbed to None so this module can still be imported, and
# `set_multi_gpus_devices` raises only when parallelism is actually requested.
try:
    import xfuser
    from xfuser.core.distributed import (get_sequence_parallel_rank,
                                         get_sequence_parallel_world_size,
                                         get_sp_group, get_world_group,
                                         init_distributed_environment,
                                         initialize_model_parallel)
    from xfuser.core.long_ctx_attention import xFuserLongContextAttention
except Exception as ex:
    # Fallback stubs: `None` marks xfuser as unavailable.
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    xFuserLongContextAttention = None
    get_sp_group = None
    get_world_group = None
    init_distributed_environment = None
    initialize_model_parallel = None
|
| 20 |
+
|
| 21 |
+
def set_multi_gpus_devices(ulysses_degree, ring_degree):
    """Set up xfuser sequence parallelism (if requested) and pick a device.

    When either degree is greater than one, torch.distributed and the xfuser
    model-parallel state are initialized and this rank's CUDA device is
    returned as a ``torch.device``; otherwise the plain string ``"cuda"`` is
    returned and nothing is initialized.

    Raises:
        RuntimeError: if parallelism is requested but xfuser is unavailable.
    """
    sequence_parallel = ulysses_degree > 1 or ring_degree > 1
    if not sequence_parallel:
        return "cuda"

    if get_sp_group is None:
        raise RuntimeError("xfuser is not installed.")
    dist.init_process_group("nccl")
    world_size = dist.get_world_size()
    print('parallel inference enabled: ulysses_degree=%d ring_degree=%d rank=%d world_size=%d' % (
        ulysses_degree, ring_degree, dist.get_rank(),
        world_size))
    assert world_size == ring_degree * ulysses_degree, \
        "number of GPUs(%d) should be equal to ring_degree * ulysses_degree." % world_size
    init_distributed_environment(rank=dist.get_rank(), world_size=world_size)
    initialize_model_parallel(
        sequence_parallel_degree=world_size,
        ring_degree=ring_degree,
        ulysses_degree=ulysses_degree,
    )
    # Bind to the local rank's GPU, not the global rank index.
    device = torch.device(f"cuda:{get_world_group().local_rank}")
    print('rank=%d device=%s' % (get_world_group().rank, str(device)))
    return device
|
videox_fun/dist/cogvideox_xfuser.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention import Attention
from diffusers.models.embeddings import apply_rotary_emb

# Optional xfuser dependency for multi-GPU (sequence-parallel) attention.
# When the import fails, the names are stubbed to None and the attention
# processor below falls back to single-GPU scaled-dot-product attention.
try:
    import xfuser
    from xfuser.core.distributed import (get_sequence_parallel_rank,
                                         get_sequence_parallel_world_size,
                                         get_sp_group,
                                         init_distributed_environment,
                                         initialize_model_parallel)
    from xfuser.core.long_ctx_attention import xFuserLongContextAttention
except Exception as ex:
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    xFuserLongContextAttention = None
    get_sp_group = None
    init_distributed_environment = None
    initialize_model_parallel = None
|
| 23 |
+
|
| 24 |
+
class CogVideoXMultiGPUsAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.

    When xfuser is available, attention over the video tokens is distributed
    across sequence-parallel ranks via ``xFuserLongContextAttention``, with the
    text tokens attached as a joint ("front") segment on every rank.
    """

    def __init__(self):
        # Prefer the xfuser hybrid sequence-parallel attention when it can be
        # constructed; otherwise fall back to plain single-GPU SDPA.
        if xFuserLongContextAttention is not None:
            try:
                self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
            except Exception:
                self.hybrid_seq_parallel_attn = None
        else:
            self.hybrid_seq_parallel_attn = None
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Returns (hidden_states, encoder_hidden_states): video-token output
        # and text-token output, split back apart at the end.
        text_seq_length = encoder_hidden_states.size(1)

        # Text and video tokens attend jointly: [text | video] along the sequence axis.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # NOTE(review): `encoder_hidden_states` cannot be None here (its size
        # was read above), so `sequence_length` is the *text* length. It is
        # only used to prepare the attention mask below.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (B, S, H*D) -> (B, H, S, D)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed. Only the video tokens (past text_seq_length)
        # are rotated; text tokens keep their unrotated projections.
        if image_rotary_emb is not None:
            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        if self.hybrid_seq_parallel_attn is None:
            # Single-GPU path: standard PyTorch 2.0 SDPA over the joint sequence.
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
            hidden_states = hidden_states
        else:
            # Multi-GPU path: xfuser expects (B, S, H, D); split video vs text
            # so the text segment can be passed as a joint tensor replicated
            # on every rank (joint_strategy='front' matches the [text|video]
            # concatenation order used above).
            img_q = query[:, :, text_seq_length:].transpose(1, 2)
            txt_q = query[:, :, :text_seq_length].transpose(1, 2)
            img_k = key[:, :, text_seq_length:].transpose(1, 2)
            txt_k = key[:, :, :text_seq_length].transpose(1, 2)
            img_v = value[:, :, text_seq_length:].transpose(1, 2)
            txt_v = value[:, :, :text_seq_length].transpose(1, 2)

            hidden_states = self.hybrid_seq_parallel_attn(
                None,
                img_q, img_k, img_v, dropout_p=0.0, causal=False,
                joint_tensor_query=txt_q,
                joint_tensor_key=txt_k,
                joint_tensor_value=txt_v,
                joint_strategy='front',
            ).transpose(1, 2)

        # (B, H, S, D) -> (B, S, H*D)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Split the joint output back into (text, video) parts.
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states
|
| 116 |
+
|
videox_fun/dist/wan_xfuser.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
import torch.cuda.amp as amp

# Optional xfuser dependency: provides the sequence-parallel primitives used
# by `rope_apply` and `usp_attn_forward` below. When xfuser is not installed,
# the names are stubbed to None and sequence-parallel attention is unavailable.
try:
    import xfuser
    from xfuser.core.distributed import (get_sequence_parallel_rank,
                                         get_sequence_parallel_world_size,
                                         get_sp_group,
                                         init_distributed_environment,
                                         initialize_model_parallel)
    from xfuser.core.long_ctx_attention import xFuserLongContextAttention
except Exception as ex:
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    xFuserLongContextAttention = None
    get_sp_group = None
    init_distributed_environment = None
    initialize_model_parallel = None
|
| 19 |
+
|
| 20 |
+
def pad_freqs(original_tensor, target_len):
    """Right-pad a (seq_len, s1, s2) frequency tensor with ones to target_len.

    Ones act as identity multipliers when the frequencies are applied as
    complex rotations, so padded positions are left unrotated.
    """
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    # new_ones inherits dtype and device from the source tensor.
    filler = original_tensor.new_ones(pad_size, s1, s2)
    return torch.cat([original_tensor, filler], dim=0)
|
| 31 |
+
|
| 32 |
+
@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """
    x: [B, L, N, C].
    grid_sizes: [B, 3].
    freqs: [M, C // 2].

    Apply 3D (frame/height/width) rotary position embedding to the local
    sequence-parallel shard `x`. The global frequency table is padded with
    identity (ones) and indexed by this rank's offset, so each rank rotates
    only its own slice of the full sequence. Runs in float32 because autocast
    is disabled (complex rotations are precision-sensitive).

    Requires xfuser: the sequence-parallel getters below are None stubs when
    xfuser is not installed.
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    # split freqs: the per-head channel budget is divided between the
    # temporal axis (which gets the remainder) and the two spatial axes.
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers: view channel pairs as complex numbers so
        # the rotation is a single complex multiply.
        x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape(
            s, n, -1, 2))
        # Broadcast each axis's frequencies over the (f, h, w) grid and
        # flatten to one frequency per flattened token position.
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ],
                            dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding: pad the global table to (local_len * ranks)
        # and select this rank's contiguous shard.
        sp_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
                                                       s_per_rank), :, :]
        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        # Re-attach positions past `s`; since s == x.size(1) this slice is
        # empty here — presumably kept for safety. TODO confirm.
        x_i = torch.cat([x_i, x[i, s:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output)
|
| 71 |
+
|
| 72 |
+
def usp_attn_forward(self,
                     x,
                     seq_lens,
                     grid_sizes,
                     freqs,
                     dtype=torch.bfloat16):
    """Sequence-parallel (USP) replacement for the Wan self-attention forward.

    Projects q/k/v, applies 3D RoPE to this rank's shard, then runs xfuser's
    hybrid Ulysses/ring long-context attention across ranks and projects the
    output. Requires xfuser to be installed.

    NOTE(review): `seq_lens` is currently unused — see the commented-out
    unpadded-attention TODO below; confirm before removing the parameter.
    """
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    half_dtypes = (torch.float16, torch.bfloat16)

    def half(x):
        # Cast to `dtype` only when not already a half-precision tensor.
        return x if x.dtype in half_dtypes else x.to(dtype)

    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)
        return q, k, v

    q, k, v = qkv_fn(x)
    # RoPE is applied to the local shard with rank-aware offsets.
    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)

    # TODO: We should use unpaded q,k,v for attention.
    # k_lens = seq_lens // get_sequence_parallel_world_size()
    # if k_lens is not None:
    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)

    x = xFuserLongContextAttention()(
        None,
        query=half(q),
        key=half(k),
        value=half(v),
        window_size=self.window_size)

    # TODO: padding after attention.
    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)

    # output: merge heads and apply the output projection.
    x = x.flatten(2)
    x = self.o(x)
    return x
|
videox_fun/models/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
|
| 2 |
+
|
| 3 |
+
from .cogvideox_transformer3d import CogVideoXTransformer3DModel
|
| 4 |
+
from .cogvideox_vae import AutoencoderKLCogVideoX
|
videox_fun/models/cache_utils.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_teacache_coefficients(model_name):
    """Look up the TeaCache rescale-polynomial coefficients for a model.

    Matching is by case-insensitive substring, checked in table order.

    Args:
        model_name: model path or name; matched case-insensitively.

    Returns:
        The coefficient list (highest degree first, suitable for
        ``np.poly1d``), or ``None`` — with a warning print — when the model
        is not supported.
    """
    # (substrings, coefficients) pairs, checked in order.
    coefficient_table = (
        (("wan2.1-t2v-1.3b", "wan2.1-fun-1.3b"),
         [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]),
        (("wan2.1-t2v-14b",),
         [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]),
        (("wan2.1-i2v-14b-480p",),
         [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]),
        (("wan2.1-i2v-14b-720p", "wan2.1-fun-14b"),
         [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]),
    )
    name = model_name.lower()  # lowercase once instead of per comparison
    for keys, coefficients in coefficient_table:
        if any(key in name for key in keys):
            return coefficients
    print(f"The model {model_name} is not supported by TeaCache.")
    return None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TeaCache():
    """
    Timestep Embedding Aware Cache (TeaCache), a training-free caching approach
    that estimates and leverages the fluctuating differences among model
    outputs across timesteps, thereby accelerating the inference.
    Please refer to:
    1. https://github.com/ali-vilab/TeaCache.
    2. Liu, Feng, et al. "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model." arXiv preprint arXiv:2411.19108 (2024).
    """
    def __init__(
        self,
        coefficients: list[float],
        num_steps: int,
        rel_l1_thresh: float = 0.0,
        num_skip_start_steps: int = 0,
        offload: bool = True,
    ):
        # Validate arguments up front.
        if num_steps < 1:
            raise ValueError(f"`num_steps` must be greater than 0 but is {num_steps}.")
        if rel_l1_thresh < 0:
            raise ValueError(f"`rel_l1_thresh` must be greater than or equal to 0 but is {rel_l1_thresh}.")
        if num_skip_start_steps < 0 or num_skip_start_steps > num_steps:
            raise ValueError(
                "`num_skip_start_steps` must be great than or equal to 0 and "
                f"less than or equal to `num_steps={num_steps}` but is {num_skip_start_steps}."
            )
        self.coefficients = coefficients
        self.num_steps = num_steps
        self.rel_l1_thresh = rel_l1_thresh
        self.num_skip_start_steps = num_skip_start_steps
        self.offload = offload
        # Polynomial that rescales the raw relative-L1 distance into a skip score.
        self.rescale_func = np.poly1d(self.coefficients)

        # Mutable per-run state lives in reset() so an instance can be reused
        # across sampling runs.
        self.reset()

    @staticmethod
    def compute_rel_l1_distance(prev: torch.Tensor, cur: torch.Tensor) -> torch.Tensor:
        """Mean absolute change of `cur` vs `prev`, relative to `prev`'s mean magnitude."""
        numerator = torch.abs(cur - prev).mean()
        denominator = torch.abs(prev).mean()
        return (numerator / denominator).cpu().item()

    def reset(self):
        """Clear all cached state so the next denoising run starts fresh."""
        self.cnt = 0
        self.should_calc = True
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        # Some pipelines concatenate the unconditional and text guide in forward.
        self.previous_residual = None
        # Some pipelines perform forward propagation separately on the unconditional and text guide.
        self.previous_residual_cond = None
        self.previous_residual_uncond = None
|
videox_fun/models/cogvideox_transformer3d.py
ADDED
|
@@ -0,0 +1,845 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import glob
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
+
from diffusers.models.attention import Attention, FeedForward
|
| 25 |
+
from diffusers.models.attention_processor import (
|
| 26 |
+
AttentionProcessor, CogVideoXAttnProcessor2_0,
|
| 27 |
+
FusedCogVideoXAttnProcessor2_0)
|
| 28 |
+
from diffusers.models.embeddings import (CogVideoXPatchEmbed,
|
| 29 |
+
TimestepEmbedding, Timesteps,
|
| 30 |
+
get_3d_sincos_pos_embed)
|
| 31 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 32 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 33 |
+
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
| 34 |
+
from diffusers.utils import is_torch_version, logging
|
| 35 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 36 |
+
from torch import nn
|
| 37 |
+
|
| 38 |
+
from ..dist import (get_sequence_parallel_rank,
|
| 39 |
+
get_sequence_parallel_world_size,
|
| 40 |
+
get_sp_group,
|
| 41 |
+
xFuserLongContextAttention)
|
| 42 |
+
from ..dist.cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class CogVideoXPatchEmbed(nn.Module):
|
| 49 |
+
    def __init__(
        self,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_positional_embeddings: bool = True,
        use_learned_positional_embeddings: bool = True,
    ) -> None:
        """Joint text+video patch embedder.

        `patch_size_t is None` selects the CogVideoX 1.0 layout (2D conv over
        each latent frame); otherwise the 1.5 layout (linear projection over
        spatio-temporal patches of size p_t x p x p) is used.
        """
        super().__init__()

        # Latent grid after spatial patching and temporal VAE compression
        # (the +1 accounts for the uncompressed first frame).
        post_patch_height = sample_height // patch_size
        post_patch_width = sample_width // patch_size
        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
        self.post_patch_height = post_patch_height
        self.post_patch_width = post_patch_width
        self.post_time_compression_frames = post_time_compression_frames
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.embed_dim = embed_dim
        self.sample_height = sample_height
        self.sample_width = sample_width
        self.sample_frames = sample_frames
        self.temporal_compression_ratio = temporal_compression_ratio
        self.max_text_seq_length = max_text_seq_length
        self.spatial_interpolation_scale = spatial_interpolation_scale
        self.temporal_interpolation_scale = temporal_interpolation_scale
        self.use_positional_embeddings = use_positional_embeddings
        self.use_learned_positional_embeddings = use_learned_positional_embeddings

        if patch_size_t is None:
            # CogVideoX 1.0 checkpoints
            self.proj = nn.Conv2d(
                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
            )
        else:
            # CogVideoX 1.5 checkpoints
            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)

        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

        if use_positional_embeddings or use_learned_positional_embeddings:
            # Learned embeddings must be saved with the checkpoint; fixed
            # sin-cos embeddings can be regenerated, so they are non-persistent.
            persistent = use_learned_positional_embeddings
            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
|
| 104 |
+
|
| 105 |
+
    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
        """Build the joint (text + video) 3D sin/cos positional embedding table.

        Returns a (1, max_text_seq_length + num_patches, embed_dim) tensor
        whose text slots are zeros and whose patch slots hold the 3D sin-cos
        embedding for the given sample resolution/length.
        """
        post_patch_height = sample_height // self.patch_size
        post_patch_width = sample_width // self.patch_size
        # Temporal VAE compression keeps the first frame, hence (n-1)//ratio + 1.
        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
        num_patches = post_patch_height * post_patch_width * post_time_compression_frames

        pos_embedding = get_3d_sincos_pos_embed(
            self.embed_dim,
            (post_patch_width, post_patch_height),
            post_time_compression_frames,
            self.spatial_interpolation_scale,
            self.temporal_interpolation_scale,
        )
        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
        # Text positions stay zero; only the video-patch region is filled in.
        joint_pos_embedding = torch.zeros(
            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
        )
        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)

        return joint_pos_embedding
|
| 125 |
+
|
| 126 |
+
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
    r"""
    Args:
        text_embeds (`torch.Tensor`):
            Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
        image_embeds (`torch.Tensor`):
            Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
    """
    # Project text features into the transformer width first so both streams share `embed_dim`.
    text_embeds = self.text_proj(text_embeds)

    text_batch_size, text_seq_length, text_channels = text_embeds.shape
    batch_size, num_frames, channels, height, width = image_embeds.shape

    if self.patch_size_t is None:
        # CogVideoX 1.0 path: spatial-only patchify via Conv2d applied per frame.
        image_embeds = image_embeds.reshape(-1, channels, height, width)
        image_embeds = self.proj(image_embeds)
        image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
    else:
        # CogVideoX 1.5 path: spatio-temporal patchify by reshaping, then a Linear projection.
        p = self.patch_size
        p_t = self.patch_size_t

        image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
        # b, f, h, w, c => b, f // 2, 2, h // 2, 2, w // 2, 2, c
        image_embeds = image_embeds.reshape(
            batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
        )
        # b, f // 2, 2, h // 2, 2, w // 2, 2, c => b, f // 2, h // 2, w // 2, c, 2, 2, 2
        image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
        image_embeds = self.proj(image_embeds)

    # Text tokens come first in the joint sequence; video patch tokens follow.
    embeds = torch.cat(
        [text_embeds, image_embeds], dim=1
    ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]

    if self.use_positional_embeddings or self.use_learned_positional_embeddings:
        seq_length = height * width * num_frames // (self.patch_size**2)
        # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
        pos_embeds = self.pos_embedding
        emb_size = embeds.size()[-1]
        # Custom modification vs. upstream diffusers: instead of truncating the table,
        # the video portion is reshaped to its (t, h, w) grid and trilinearly resampled
        # to the current input resolution, so non-default heights/widths are supported.
        # NOTE(review): self.post_time_compression_frames / self.post_patch_height /
        # self.post_patch_width are not assigned in this class's visible methods —
        # presumably set in __init__ (outside this view); confirm.
        pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
        pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
        pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.patch_size, width // self.patch_size], mode='trilinear', align_corners=False)
        pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
        pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
        pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
        embeds = embeds + pos_embeds

    return embeds
|
| 177 |
+
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool` defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self Attention
        # CogVideoXLayerNormZero produces modulated (scaled/shifted) activations plus
        # per-stream gates used for the residual connections below.
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        # Joint text+video block: both streams are normalized/gated separately but
        # attended to jointly; returns (hidden_states, encoder_hidden_states).
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate (timestep-conditioned AdaLN-Zero style)
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention (processor concatenates the two streams internally)
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # gated residual connections, one gate per stream
        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward runs on the concatenated sequence (text first, then video)
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        # split the FF output back into the two streams by text_seq_length
        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states
|
| 296 |
+
|
| 297 |
+
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
| 298 |
+
"""
|
| 299 |
+
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
| 300 |
+
|
| 301 |
+
Parameters:
|
| 302 |
+
num_attention_heads (`int`, defaults to `30`):
|
| 303 |
+
The number of heads to use for multi-head attention.
|
| 304 |
+
attention_head_dim (`int`, defaults to `64`):
|
| 305 |
+
The number of channels in each head.
|
| 306 |
+
in_channels (`int`, defaults to `16`):
|
| 307 |
+
The number of channels in the input.
|
| 308 |
+
out_channels (`int`, *optional*, defaults to `16`):
|
| 309 |
+
The number of channels in the output.
|
| 310 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
| 311 |
+
Whether to flip the sin to cos in the time embedding.
|
| 312 |
+
time_embed_dim (`int`, defaults to `512`):
|
| 313 |
+
Output dimension of timestep embeddings.
|
| 314 |
+
text_embed_dim (`int`, defaults to `4096`):
|
| 315 |
+
Input dimension of text embeddings from the text encoder.
|
| 316 |
+
num_layers (`int`, defaults to `30`):
|
| 317 |
+
The number of layers of Transformer blocks to use.
|
| 318 |
+
dropout (`float`, defaults to `0.0`):
|
| 319 |
+
The dropout probability to use.
|
| 320 |
+
attention_bias (`bool`, defaults to `True`):
|
| 321 |
+
Whether or not to use bias in the attention projection layers.
|
| 322 |
+
sample_width (`int`, defaults to `90`):
|
| 323 |
+
The width of the input latents.
|
| 324 |
+
sample_height (`int`, defaults to `60`):
|
| 325 |
+
The height of the input latents.
|
| 326 |
+
sample_frames (`int`, defaults to `49`):
|
| 327 |
+
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
| 328 |
+
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
|
| 329 |
+
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
| 330 |
+
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
| 331 |
+
patch_size (`int`, defaults to `2`):
|
| 332 |
+
The size of the patches to use in the patch embedding layer.
|
| 333 |
+
temporal_compression_ratio (`int`, defaults to `4`):
|
| 334 |
+
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
|
| 335 |
+
max_text_seq_length (`int`, defaults to `226`):
|
| 336 |
+
The maximum sequence length of the input text embeddings.
|
| 337 |
+
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
| 338 |
+
Activation function to use in feed-forward.
|
| 339 |
+
timestep_activation_fn (`str`, defaults to `"silu"`):
|
| 340 |
+
Activation function to use when generating the timestep embeddings.
|
| 341 |
+
norm_elementwise_affine (`bool`, defaults to `True`):
|
| 342 |
+
Whether or not to use elementwise affine in normalization layers.
|
| 343 |
+
norm_eps (`float`, defaults to `1e-5`):
|
| 344 |
+
The epsilon value to use in normalization layers.
|
| 345 |
+
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
| 346 |
+
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
| 347 |
+
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
| 348 |
+
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
_supports_gradient_checkpointing = True
|
| 352 |
+
|
| 353 |
+
@register_to_config
def __init__(
    self,
    num_attention_heads: int = 30,
    attention_head_dim: int = 64,
    in_channels: int = 16,
    out_channels: Optional[int] = 16,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    time_embed_dim: int = 512,
    text_embed_dim: int = 4096,
    num_layers: int = 30,
    dropout: float = 0.0,
    attention_bias: bool = True,
    sample_width: int = 90,
    sample_height: int = 60,
    sample_frames: int = 49,
    patch_size: int = 2,
    patch_size_t: Optional[int] = None,
    temporal_compression_ratio: int = 4,
    max_text_seq_length: int = 226,
    activation_fn: str = "gelu-approximate",
    timestep_activation_fn: str = "silu",
    norm_elementwise_affine: bool = True,
    norm_eps: float = 1e-5,
    spatial_interpolation_scale: float = 1.875,
    temporal_interpolation_scale: float = 1.0,
    use_rotary_positional_embeddings: bool = False,
    use_learned_positional_embeddings: bool = False,
    patch_bias: bool = True,
    add_noise_in_inpaint_model: bool = False,
):
    """Build patch/time embeddings, the transformer stack and the output head.

    All keyword arguments are captured into ``self.config`` by
    ``@register_to_config``. ``patch_size_t is None`` selects the CogVideoX 1.0
    layout (Conv2d patchify, spatial-only output head); otherwise the 1.5
    spatio-temporal layout is used.
    """
    super().__init__()
    inner_dim = num_attention_heads * attention_head_dim
    # Mirrored as an instance attribute because forward() branches on it.
    self.patch_size_t = patch_size_t
    if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
        raise ValueError(
            "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
            "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
            "issue at https://github.com/huggingface/diffusers/issues."
        )

    # 1. Patch embedding (joint text + video token sequence)
    self.patch_embed = CogVideoXPatchEmbed(
        patch_size=patch_size,
        patch_size_t=patch_size_t,
        in_channels=in_channels,
        embed_dim=inner_dim,
        text_embed_dim=text_embed_dim,
        bias=patch_bias,
        sample_width=sample_width,
        sample_height=sample_height,
        sample_frames=sample_frames,
        temporal_compression_ratio=temporal_compression_ratio,
        max_text_seq_length=max_text_seq_length,
        spatial_interpolation_scale=spatial_interpolation_scale,
        temporal_interpolation_scale=temporal_interpolation_scale,
        use_positional_embeddings=not use_rotary_positional_embeddings,
        use_learned_positional_embeddings=use_learned_positional_embeddings,
    )
    self.embedding_dropout = nn.Dropout(dropout)

    # 2. Time embeddings
    self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
    self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

    # 3. Define spatio-temporal transformers blocks
    self.transformer_blocks = nn.ModuleList(
        [
            CogVideoXBlock(
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                time_embed_dim=time_embed_dim,
                dropout=dropout,
                activation_fn=activation_fn,
                attention_bias=attention_bias,
                norm_elementwise_affine=norm_elementwise_affine,
                norm_eps=norm_eps,
            )
            for _ in range(num_layers)
        ]
    )
    self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

    # 4. Output blocks
    self.norm_out = AdaLayerNorm(
        embedding_dim=time_embed_dim,
        output_dim=2 * inner_dim,
        norm_elementwise_affine=norm_elementwise_affine,
        norm_eps=norm_eps,
        chunk_dim=1,
    )

    if patch_size_t is None:
        # For CogVideox 1.0
        output_dim = patch_size * patch_size * out_channels
    else:
        # For CogVideoX 1.5
        output_dim = patch_size * patch_size * patch_size_t * out_channels

    self.proj_out = nn.Linear(inner_dim, output_dim)

    self.gradient_checkpointing = False
    # Sequence-parallel defaults (single GPU); overridden by
    # enable_multi_gpus_inference() for context-parallel runs.
    self.sp_world_size = 1
    self.sp_world_rank = 0
| 460 |
+
def _set_gradient_checkpointing(self, module, value=False):
    """Toggle gradient checkpointing; `module` is accepted for API compatibility but unused."""
    setattr(self, "gradient_checkpointing", value)
|
| 463 |
+
def enable_multi_gpus_inference(self,):
    """Switch the model to sequence-parallel (context-parallel) inference.

    Records the sequence-parallel rank/world-size and installs the
    multi-GPU attention processor on every attention layer.
    """
    self.sp_world_rank = get_sequence_parallel_rank()
    self.sp_world_size = get_sequence_parallel_world_size()
    self.set_attn_processor(CogVideoXMultiGPUsAttnProcessor2_0())
|
| 468 |
+
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: A dictionary containing all attention processors used in the model with
        indexed by its weight name.
    """
    found: Dict[str, AttentionProcessor] = {}

    def _collect(prefix: str, mod: torch.nn.Module) -> None:
        # Any submodule exposing `get_processor` is an attention layer.
        if hasattr(mod, "get_processor"):
            found[f"{prefix}.processor"] = mod.get_processor()
        for child_name, child in mod.named_children():
            _collect(f"{prefix}.{child_name}", child)

    for top_name, top_module in self.named_children():
        _collect(top_name, top_module)

    return found
|
| 493 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def _assign(path: str, mod: torch.nn.Module, proc) -> None:
        if hasattr(mod, "set_processor"):
            # A dict maps each attention layer's dotted path to its own processor;
            # a single processor instance is shared by every layer.
            if isinstance(proc, dict):
                mod.set_processor(proc.pop(f"{path}.processor"))
            else:
                mod.set_processor(proc)

        for child_name, child in mod.named_children():
            _assign(f"{path}.{child_name}", child, proc)

    for top_name, top_module in self.named_children():
        _assign(top_name, top_module, processor)
|
| 528 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
def fuse_qkv_projections(self):
    """
    Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
    are fused. For cross-attention modules, key and value projection matrices are fused.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>
    """
    self.original_attn_processors = None

    # Fusion is unsupported when any layer carries added KV projections.
    if any("Added" in type(proc).__name__ for proc in self.attn_processors.values()):
        raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

    # Remember the current processors so unfuse_qkv_projections() can restore them.
    self.original_attn_processors = self.attn_processors

    for module in self.modules():
        if isinstance(module, Attention):
            module.fuse_projections(fuse=True)

    self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
| 554 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    """
    # Fix: `original_attn_processors` is only assigned inside fuse_qkv_projections();
    # __init__ never initializes it, so calling unfuse first used to raise
    # AttributeError. Guard with getattr so unfusing without a prior fuse is a no-op.
    original = getattr(self, "original_attn_processors", None)
    if original is not None:
        self.set_attn_processor(original)
|
| 568 |
+
def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: Union[int, float, torch.LongTensor],
    timestep_cond: Optional[torch.Tensor] = None,
    inpaint_latents: Optional[torch.Tensor] = None,
    control_latents: Optional[torch.Tensor] = None,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    return_dict: bool = True,
):
    """Denoising forward pass.

    Args:
        hidden_states: Video latents, shape (batch, num_frames, channels, height, width).
        encoder_hidden_states: Text embeddings, shape (batch, text_seq_len, text_embed_dim).
        timestep: Diffusion timestep(s) fed to the time projection.
        timestep_cond: Optional extra conditioning for the time embedding.
        inpaint_latents / control_latents: Optional latents concatenated on the
            channel dim before patch embedding (inpaint/control variants).
        image_rotary_emb: Optional (cos, sin) rotary embeddings for video tokens.
        return_dict: When False, returns a 1-tuple instead of Transformer2DModelOutput.
    """
    batch_size, num_frames, channels, height, width = hidden_states.shape
    # CogVideoX 1.5 (patch_size_t set) patchifies pairs of frames, so a single-frame
    # input is padded with a zero frame here and the pad is sliced off after unpatchify.
    if num_frames == 1 and self.patch_size_t is not None:
        hidden_states = torch.cat([hidden_states, torch.zeros_like(hidden_states)], dim=1)
        if inpaint_latents is not None:
            inpaint_latents = torch.concat([inpaint_latents, torch.zeros_like(inpaint_latents)], dim=1)
        if control_latents is not None:
            control_latents = torch.concat([control_latents, torch.zeros_like(control_latents)], dim=1)
        local_num_frames = num_frames + 1
    else:
        local_num_frames = num_frames

    # 1. Time embedding
    timesteps = timestep
    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=hidden_states.dtype)
    emb = self.time_embedding(t_emb, timestep_cond)

    # 2. Patch embedding (optional inpaint/control latents are stacked on channels first)
    if inpaint_latents is not None:
        hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
    if control_latents is not None:
        hidden_states = torch.concat([hidden_states, control_latents], 2)
    hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
    hidden_states = self.embedding_dropout(hidden_states)

    # patch_embed returned the joint sequence; split it back into text and video streams.
    text_seq_length = encoder_hidden_states.shape[1]
    encoder_hidden_states = hidden_states[:, :text_seq_length]
    hidden_states = hidden_states[:, text_seq_length:]

    # Context Parallel: each rank keeps only its chunk of the video tokens
    # (and the matching slice of the rotary embeddings).
    if self.sp_world_size > 1:
        hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
        if image_rotary_emb is not None:
            image_rotary_emb = (
                torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
                torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank]
            )

    # 3. Transformer blocks
    for i, block in enumerate(self.transformer_blocks):
        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                encoder_hidden_states,
                emb,
                image_rotary_emb,
                **ckpt_kwargs,
            )
        else:
            hidden_states, encoder_hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=emb,
                image_rotary_emb=image_rotary_emb,
            )

    if not self.config.use_rotary_positional_embeddings:
        # CogVideoX-2B
        hidden_states = self.norm_final(hidden_states)
    else:
        # CogVideoX-5B: final norm runs over the joint sequence, then text is dropped.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        hidden_states = self.norm_final(hidden_states)
        hidden_states = hidden_states[:, text_seq_length:]

    # 4. Final block
    hidden_states = self.norm_out(hidden_states, temb=emb)
    hidden_states = self.proj_out(hidden_states)

    # Reassemble the full token sequence across sequence-parallel ranks.
    if self.sp_world_size > 1:
        hidden_states = get_sp_group().all_gather(hidden_states, dim=1)

    # 5. Unpatchify
    p = self.config.patch_size
    p_t = self.config.patch_size_t

    if p_t is None:
        output = hidden_states.reshape(batch_size, local_num_frames, height // p, width // p, -1, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
    else:
        output = hidden_states.reshape(
            batch_size, (local_num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
        )
        output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

    # Drop the zero frame padded above for single-frame CogVideoX 1.5 inputs.
    if num_frames == 1:
        output = output[:, :num_frames, :]

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
|
| 684 |
+
@classmethod
|
| 685 |
+
def from_pretrained(
|
| 686 |
+
cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
|
| 687 |
+
low_cpu_mem_usage=False, torch_dtype=torch.bfloat16, use_vae_mask=False, stack_mask=False,
|
| 688 |
+
):
|
| 689 |
+
if subfolder is not None:
|
| 690 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
| 691 |
+
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
|
| 692 |
+
|
| 693 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
| 694 |
+
if not os.path.isfile(config_file):
|
| 695 |
+
raise RuntimeError(f"{config_file} does not exist")
|
| 696 |
+
with open(config_file, "r") as f:
|
| 697 |
+
config = json.load(f)
|
| 698 |
+
|
| 699 |
+
if use_vae_mask:
|
| 700 |
+
print('[DEBUG] use vae to encode mask')
|
| 701 |
+
config['in_channels'] = 48
|
| 702 |
+
elif stack_mask:
|
| 703 |
+
print('[DEBUG] use stacking mask')
|
| 704 |
+
config['in_channels'] = 36
|
| 705 |
+
|
| 706 |
+
from diffusers.utils import WEIGHTS_NAME
|
| 707 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
| 708 |
+
model_file_safetensors = model_file.replace(".bin", ".safetensors")
|
| 709 |
+
|
| 710 |
+
if "dict_mapping" in transformer_additional_kwargs.keys():
|
| 711 |
+
for key in transformer_additional_kwargs["dict_mapping"]:
|
| 712 |
+
transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
|
| 713 |
+
|
| 714 |
+
if low_cpu_mem_usage:
|
| 715 |
+
try:
|
| 716 |
+
import re
|
| 717 |
+
|
| 718 |
+
from diffusers.models.modeling_utils import \
|
| 719 |
+
load_model_dict_into_meta
|
| 720 |
+
from diffusers.utils import is_accelerate_available
|
| 721 |
+
if is_accelerate_available():
|
| 722 |
+
import accelerate
|
| 723 |
+
|
| 724 |
+
# Instantiate model with empty weights
|
| 725 |
+
with accelerate.init_empty_weights():
|
| 726 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
| 727 |
+
|
| 728 |
+
param_device = "cpu"
|
| 729 |
+
if os.path.exists(model_file):
|
| 730 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 731 |
+
elif os.path.exists(model_file_safetensors):
|
| 732 |
+
from safetensors.torch import load_file, safe_open
|
| 733 |
+
state_dict = load_file(model_file_safetensors)
|
| 734 |
+
else:
|
| 735 |
+
from safetensors.torch import load_file, safe_open
|
| 736 |
+
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
|
| 737 |
+
state_dict = {}
|
| 738 |
+
for _model_file_safetensors in model_files_safetensors:
|
| 739 |
+
_state_dict = load_file(_model_file_safetensors)
|
| 740 |
+
for key in _state_dict:
|
| 741 |
+
state_dict[key] = _state_dict[key]
|
| 742 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
| 743 |
+
# move the params from meta device to cpu
|
| 744 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
| 745 |
+
if len(missing_keys) > 0:
|
| 746 |
+
raise ValueError(
|
| 747 |
+
f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
|
| 748 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
| 749 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
| 750 |
+
" those weights or else make sure your checkpoint file is correct."
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
unexpected_keys = load_model_dict_into_meta(
|
| 754 |
+
model,
|
| 755 |
+
state_dict,
|
| 756 |
+
device=param_device,
|
| 757 |
+
dtype=torch_dtype,
|
| 758 |
+
model_name_or_path=pretrained_model_path,
|
| 759 |
+
)
|
| 760 |
+
|
| 761 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
| 762 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
| 763 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
| 764 |
+
|
| 765 |
+
if len(unexpected_keys) > 0:
|
| 766 |
+
print(
|
| 767 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
| 768 |
+
)
|
| 769 |
+
return model
|
| 770 |
+
except Exception as e:
|
| 771 |
+
print(
|
| 772 |
+
f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
| 776 |
+
if os.path.exists(model_file):
|
| 777 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
| 778 |
+
elif os.path.exists(model_file_safetensors):
|
| 779 |
+
from safetensors.torch import load_file, safe_open
|
| 780 |
+
state_dict = load_file(model_file_safetensors)
|
| 781 |
+
else:
|
| 782 |
+
from safetensors.torch import load_file, safe_open
|
| 783 |
+
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
|
| 784 |
+
state_dict = {}
|
| 785 |
+
for _model_file_safetensors in model_files_safetensors:
|
| 786 |
+
_state_dict = load_file(_model_file_safetensors)
|
| 787 |
+
for key in _state_dict:
|
| 788 |
+
state_dict[key] = _state_dict[key]
|
| 789 |
+
|
| 790 |
+
if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
|
| 791 |
+
new_shape = model.state_dict()['patch_embed.proj.weight'].size()
|
| 792 |
+
if len(new_shape) == 5:
|
| 793 |
+
state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
|
| 794 |
+
state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
|
| 795 |
+
elif len(new_shape) == 2:
|
| 796 |
+
if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
|
| 797 |
+
if use_vae_mask:
|
| 798 |
+
print('[DEBUG] patch_embed.proj.weight size does not match due to vae-encoded mask')
|
| 799 |
+
latent_ch = 16
|
| 800 |
+
feat_scale = 8
|
| 801 |
+
feat_dim = int(latent_ch * feat_scale)
|
| 802 |
+
old_total_dim = state_dict['patch_embed.proj.weight'].size(1)
|
| 803 |
+
new_total_dim = model.state_dict()['patch_embed.proj.weight'].size(1)
|
| 804 |
+
model.state_dict()['patch_embed.proj.weight'][:, :feat_dim] = state_dict['patch_embed.proj.weight'][:, :feat_dim]
|
| 805 |
+
model.state_dict()['patch_embed.proj.weight'][:, -feat_dim:] = state_dict['patch_embed.proj.weight'][:, -feat_dim:]
|
| 806 |
+
for i in range(feat_dim, new_total_dim - feat_dim, feat_scale):
|
| 807 |
+
model.state_dict()['patch_embed.proj.weight'][:, i:i+feat_scale] = state_dict['patch_embed.proj.weight'][:, feat_dim:-feat_dim]
|
| 808 |
+
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
|
| 809 |
+
else:
|
| 810 |
+
model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1]] = state_dict['patch_embed.proj.weight']
|
| 811 |
+
model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:] = 0
|
| 812 |
+
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
|
| 813 |
+
else:
|
| 814 |
+
model.state_dict()['patch_embed.proj.weight'][:, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1]]
|
| 815 |
+
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
|
| 816 |
+
else:
|
| 817 |
+
if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
|
| 818 |
+
model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
|
| 819 |
+
model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
|
| 820 |
+
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
|
| 821 |
+
else:
|
| 822 |
+
model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
|
| 823 |
+
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
|
| 824 |
+
|
| 825 |
+
tmp_state_dict = {}
|
| 826 |
+
for key in state_dict:
|
| 827 |
+
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
|
| 828 |
+
tmp_state_dict[key] = state_dict[key]
|
| 829 |
+
else:
|
| 830 |
+
print(key, "Size don't match, skip")
|
| 831 |
+
|
| 832 |
+
state_dict = tmp_state_dict
|
| 833 |
+
|
| 834 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
| 835 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 836 |
+
print(m)
|
| 837 |
+
|
| 838 |
+
params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
|
| 839 |
+
print(f"### All Parameters: {sum(params) / 1e6} M")
|
| 840 |
+
|
| 841 |
+
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
|
| 842 |
+
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
|
| 843 |
+
|
| 844 |
+
model = model.to(torch_dtype)
|
| 845 |
+
return model
|
videox_fun/models/cogvideox_vae.py
ADDED
|
@@ -0,0 +1,1675 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import Dict, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
|
| 25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 26 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 27 |
+
from diffusers.utils import logging
|
| 28 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
| 29 |
+
from diffusers.models.activations import get_activation
|
| 30 |
+
from diffusers.models.downsampling import CogVideoXDownsample3D
|
| 31 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 32 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 33 |
+
from diffusers.models.upsampling import CogVideoXUpsample3D
|
| 34 |
+
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class CogVideoXSafeConv3d(nn.Conv3d):
    r"""
    A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Approximate fp16 footprint of the input in GiB (product of all five
        # dims, two bytes per element).
        memory_gib = (
            (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
        )

        # Inputs under the ~2 GiB CuDNN-friendly budget go through unchanged.
        if memory_gib <= 2:
            return super().forward(input)

        temporal_kernel = self.kernel_size[0]
        num_parts = int(memory_gib / 2) + 1
        chunks = list(torch.chunk(input, num_parts, dim=2))

        if temporal_kernel > 1:
            # Overlap-save: prepend the trailing (k-1) frames of the previous
            # chunk so every piece sees the temporal context the kernel needs.
            chunks = [chunks[0]] + [
                torch.cat((chunks[idx - 1][:, :, -temporal_kernel + 1 :], chunks[idx]), dim=2)
                for idx in range(1, len(chunks))
            ]

        # Explicit-argument super(): zero-arg super() is unavailable inside a
        # comprehension scope.
        outputs = [super(CogVideoXSafeConv3d, self).forward(piece) for piece in chunks]
        return torch.cat(outputs, dim=2)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class CogVideoXCausalConv3d(nn.Module):
    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.

    Args:
        in_channels (`int`): Number of channels in the input tensor.
        out_channels (`int`): Number of output channels produced by the convolution.
        kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
        stride (`int`, defaults to `1`): Stride of the convolution.
        dilation (`int`, defaults to `1`): Dilation rate of the convolution.
        pad_mode (`str`, defaults to `"constant"`): Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: int = 1,
        dilation: int = 1,
        pad_mode: str = "constant",
    ):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)

        t_kernel, h_kernel, w_kernel = kernel_size

        # Causality: pad the full (k - 1) frames on the past side of the time
        # axis only; height/width are padded symmetrically.
        # TODO(aryan): configure calculation based on stride and dilation in the future.
        # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
        self.pad_mode = pad_mode
        self.time_pad = t_kernel - 1
        self.height_pad = (h_kernel - 1) // 2
        self.width_pad = (w_kernel - 1) // 2
        self.time_causal_padding = (
            self.width_pad,
            self.width_pad,
            self.height_pad,
            self.height_pad,
            self.time_pad,
            0,
        )

        self.temporal_dim = 2
        self.time_kernel_size = t_kernel

        if not isinstance(stride, tuple):
            stride = (stride, 1, 1)
        self.conv = CogVideoXSafeConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=(dilation, 1, 1),
        )

    def fake_context_parallel_forward(
        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Prepend temporal context (cached frames or repeats of frame 0)."""
        if self.pad_mode == "replicate":
            return F.pad(inputs, self.time_causal_padding, mode="replicate")
        if self.time_kernel_size > 1:
            # Use cached frames from the previous call when available, else
            # replicate the first frame to fill the causal window.
            history = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (self.time_kernel_size - 1)
            inputs = torch.cat(history + [inputs], dim=2)
        return inputs

    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
        inputs = self.fake_context_parallel_forward(inputs, conv_cache)

        if self.pad_mode == "replicate":
            conv_cache = None
        else:
            # Remember the trailing frames so the next chunk stays causal,
            # then zero-pad height/width for the spatial kernel.
            conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
            spatial_padding = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
            inputs = F.pad(inputs, spatial_padding, mode="constant", value=0)

        return self.conv(inputs), conv_cache
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class CogVideoXSpatialNorm3D(nn.Module):
    r"""
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
    to 3D-video like data.

    CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.

    Args:
        f_channels (`int`):
            The number of channels for input to group normalization layer, and output of the spatial norm layer.
        zq_channels (`int`):
            The number of channels for the quantized vector as described in the paper.
        groups (`int`):
            Number of groups to separate the channels into for group normalization.
    """

    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
        groups: int = 32,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
        self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
        self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)

    def forward(
        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
    ) -> torch.Tensor:
        """Modulate group-normalized `f` with scale/shift derived from `zq`."""
        new_conv_cache = {}
        conv_cache = conv_cache or {}

        num_frames = f.shape[2]
        if num_frames > 1 and num_frames % 2 == 1:
            # Odd frame counts: the leading (causal) frame is resized on its
            # own so it matches its own spatial target independently.
            zq_first = F.interpolate(zq[:, :, :1], size=f[:, :, :1].shape[-3:])
            zq_rest = F.interpolate(zq[:, :, 1:], size=f[:, :, 1:].shape[-3:])
            zq = torch.cat([zq_first, zq_rest], dim=2)
        else:
            zq = F.interpolate(zq, size=f.shape[-3:])

        scale, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
        shift, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))

        return self.norm_layer(f) * scale + shift, new_conv_cache
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class CogVideoXUpsample3D(nn.Module):
    r"""
    A 3D Upsample layer used in CogVideoX by Tsinghua University & ZhipuAI.

    Args:
        in_channels (`int`):
            Number of channels in the input image.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `1`):
            Stride of the convolution.
        padding (`int`, defaults to `1`):
            Padding added to all four sides of the input.
        compress_time (`bool`, defaults to `False`):
            Whether or not to compress the time dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        compress_time: bool = False,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.compress_time = compress_time

        # Runtime toggles for chunked/streaming decode: when
        # `auto_split_process` is off, `first_frame_flag` tells the layer
        # whether the current single-frame chunk is the causal first frame.
        self.auto_split_process = True
        self.first_frame_flag = False

    def _temporal_upsample(self, inputs: torch.Tensor) -> torch.Tensor:
        # 2x nearest upsampling that also doubles the time axis, with the
        # causal first frame handled separately (spatial-only upsample).
        if self.auto_split_process:
            frames = inputs.shape[2]
            if frames > 1 and frames % 2 == 1:
                # Odd count: keep frame 0 temporally fixed, upsample the rest.
                head = F.interpolate(inputs[:, :, 0], scale_factor=2.0)
                tail = F.interpolate(inputs[:, :, 1:], scale_factor=2.0)
                return torch.cat([head[:, :, None, :, :], tail], dim=2)
            if frames > 1:
                return F.interpolate(inputs, scale_factor=2.0)
            # Single frame: upsample spatially only.
            head = F.interpolate(inputs.squeeze(2), scale_factor=2.0)
            return head[:, :, None, :, :]
        if self.first_frame_flag:
            head = F.interpolate(inputs.squeeze(2), scale_factor=2.0)
            return head[:, :, None, :, :]
        return F.interpolate(inputs, scale_factor=2.0)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if self.compress_time:
            inputs = self._temporal_upsample(inputs)
        else:
            # Spatial-only 2x upsample, frame by frame.
            b, c, t, h, w = inputs.shape
            flat = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
            flat = F.interpolate(flat, scale_factor=2.0)
            inputs = flat.reshape(b, t, c, *flat.shape[2:]).permute(0, 2, 1, 3, 4)

        # Per-frame 2D convolution on the upsampled result.
        b, c, t, h, w = inputs.shape
        flat = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        flat = self.conv(flat)
        return flat.reshape(b, t, *flat.shape[1:]).permute(0, 2, 1, 3, 4)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class CogVideoXResnetBlock3D(nn.Module):
    r"""
    A 3D ResNet block used in the CogVideoX model.

    Normalize -> activate -> conv (twice), with an optional time-embedding
    injection between the two convolutions and a residual shortcut that is
    projected when the channel count changes.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels. Set to 0 to disable the
            time-embedding projection entirely.
        groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        non_linearity (`str`, defaults to `"swish"`):
            Activation function to use.
        conv_shortcut (bool, defaults to `False`):
            Whether or not to use a convolution shortcut.
        spatial_norm_dim (`int`, *optional*):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        conv_shortcut: bool = False,
        spatial_norm_dim: Optional[int] = None,
        pad_mode: str = "first",
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nonlinearity = get_activation(non_linearity)
        self.use_conv_shortcut = conv_shortcut
        self.spatial_norm_dim = spatial_norm_dim

        # Decoder blocks condition normalization on the latent `zq`
        # (CogVideoXSpatialNorm3D); otherwise plain group norm is used.
        if spatial_norm_dim is None:
            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
            self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
        else:
            self.norm1 = CogVideoXSpatialNorm3D(
                f_channels=in_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )
            self.norm2 = CogVideoXSpatialNorm3D(
                f_channels=out_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )

        self.conv1 = CogVideoXCausalConv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        # Only created when a time embedding is actually used.
        if temb_channels > 0:
            self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)

        self.dropout = nn.Dropout(dropout)
        self.conv2 = CogVideoXCausalConv3d(
            in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        # Residual projection when channel counts differ: either a causal 3x3
        # conv or a pointwise (1x1x1) conv depending on `use_conv_shortcut`.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = CogVideoXCausalConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
                )
            else:
                self.conv_shortcut = CogVideoXSafeConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(
        self,
        inputs: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run the residual block.

        Returns a `(hidden_states, new_conv_cache)` tuple. `conv_cache` holds
        the temporal caches of the causal sub-layers from the previous chunk,
        keyed by sub-layer name ("norm1", "conv1", "norm2", "conv2",
        "conv_shortcut"); the fresh caches are returned for the next chunk.
        """
        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states = inputs

        # `zq` present => spatial-norm path (returns a cache); absent =>
        # group-norm path (cache-free).
        if zq is not None:
            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))

        # Inject the (activated, projected) time embedding, broadcast over
        # the time/height/width axes.
        if temb is not None:
            hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]

        if zq is not None:
            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
        else:
            hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))

        # Project the residual branch when channel counts differ.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
                    inputs, conv_cache=conv_cache.get("conv_shortcut")
                )
            else:
                inputs = self.conv_shortcut(inputs)

        hidden_states = hidden_states + inputs
        return hidden_states, new_conv_cache
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class CogVideoXDownBlock3D(nn.Module):
    r"""
    A downsampling block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        add_downsample (`bool`, defaults to `True`):
            Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
        compress_time (`bool`, defaults to `False`):
            Whether or not to downsample across temporal dimension.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_padding: int = 0,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        resnets = []
        for i in range(num_layers):
            # only the first resnet changes the channel count
            in_channel = in_channels if i == 0 else out_channels
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.downsamplers = None

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    CogVideoXDownsample3D(
                        out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""Forward method of the `CogVideoXDownBlock3D` class.

        Returns the downsampled hidden states together with a per-resnet conv
        cache dict (keys ``"resnet_{i}"``) for causal streaming across temporal
        chunks. Note the downsamplers do not participate in the conv cache.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # recompute activations in backward to save memory

                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    zq,
                    conv_cache.get(conv_cache_key),
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states, new_conv_cache
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
class CogVideoXMidBlock3D(nn.Module):
    r"""
    A middle block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        spatial_norm_dim (`int`, *optional*):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: Optional[int] = None,
        pad_mode: str = "first",
    ):
        super().__init__()

        # all mid-block resnets keep the channel count unchanged
        resnets = []
        for _ in range(num_layers):
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    spatial_norm_dim=spatial_norm_dim,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
            )
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""Forward method of the `CogVideoXMidBlock3D` class.

        Returns the transformed hidden states together with a per-resnet conv
        cache dict (keys ``"resnet_{i}"``) for causal streaming across temporal
        chunks.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # recompute activations in backward to save memory

                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        return hidden_states, new_conv_cache
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class CogVideoXUpBlock3D(nn.Module):
    r"""
    An upsampling block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        spatial_norm_dim (`int`, defaults to `16`):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        add_upsample (`bool`, defaults to `True`):
            Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
        compress_time (`bool`, defaults to `False`):
            Whether or not to upsample across temporal dimension.
        pad_mode (str, defaults to `"first"`):
            Padding mode.
    """

    # Fix: declared for consistency with CogVideoXDownBlock3D / CogVideoXMidBlock3D.
    # `forward` already branches on `self.gradient_checkpointing`, so this block
    # supports checkpointing exactly like its siblings and should advertise it.
    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: int = 16,
        add_upsample: bool = True,
        upsample_padding: int = 1,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        resnets = []
        for i in range(num_layers):
            # only the first resnet changes the channel count
            in_channel = in_channels if i == 0 else out_channels
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    spatial_norm_dim=spatial_norm_dim,
                    pad_mode=pad_mode,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.upsamplers = None

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    CogVideoXUpsample3D(
                        out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""Forward method of the `CogVideoXUpBlock3D` class.

        Returns the upsampled hidden states together with a per-resnet conv
        cache dict (keys ``"resnet_{i}"``) for causal streaming across temporal
        chunks. Note the upsamplers do not participate in the conv cache.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # recompute activations in backward to save memory

                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    zq,
                    conv_cache.get(conv_cache_key),
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states, new_conv_cache
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class CogVideoXEncoder3D(nn.Module):
    r"""
    The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 16):
            The number of latent output channels (the final conv produces `2 * out_channels`
            channels: mean and log-variance).
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXDownBlock3D"`):
            The types of down blocks to use. Only `"CogVideoXDownBlock3D"` is supported.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        layers_per_block (`int`, *optional*, defaults to 3):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        down_block_types: Tuple[str, ...] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # log2 of temporal_compress_times
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
        self.down_blocks = nn.ModuleList([])

        # down blocks
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # only the first `temporal_compress_level` blocks downsample time
            compress_time = i < temporal_compress_level

            if down_block_type == "CogVideoXDownBlock3D":
                down_block = CogVideoXDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    add_downsample=not is_final_block,
                    compress_time=compress_time,
                )
            else:
                raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")

            self.down_blocks.append(down_block)

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=block_out_channels[-1],
            temb_channels=0,
            dropout=dropout,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            pad_mode=pad_mode,
        )

        self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
        self.conv_act = nn.SiLU()
        # 2 * out_channels: the encoder emits mean and log-variance for the latent distribution
        self.conv_out = CogVideoXCausalConv3d(
            block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""The forward method of the `CogVideoXEncoder3D` class.

        Returns the encoded moments tensor together with a nested conv cache
        (keys ``"conv_in"``, ``"down_block_{i}"``, ``"mid_block"``, ``"conv_out"``)
        for causal streaming across temporal chunks.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Down
            for i, down_block in enumerate(self.down_blocks):
                conv_cache_key = f"down_block_{i}"
                # zq is None during encoding; blocks use plain group norm
                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block),
                    hidden_states,
                    temb,
                    None,
                    conv_cache.get(conv_cache_key),
                )

            # 2. Mid
            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                temb,
                None,
                conv_cache.get("mid_block"),
            )
        else:
            # 1. Down
            for i, down_block in enumerate(self.down_blocks):
                conv_cache_key = f"down_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = down_block(
                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
                )

            # 2. Mid
            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
            )

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)

        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))

        return hidden_states, new_conv_cache
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
class CogVideoXDecoder3D(nn.Module):
    r"""
    The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            The number of input latent channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXUpBlock3D"`):
            The types of up blocks to use. Only `"CogVideoXUpBlock3D"` is supported.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block (iterated in reverse for decoding).
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        layers_per_block (`int`, *optional*, defaults to 3):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # decoder walks the encoder's channel schedule in reverse
        reversed_block_out_channels = list(reversed(block_out_channels))

        self.conv_in = CogVideoXCausalConv3d(
            in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
        )

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=reversed_block_out_channels[0],
            temb_channels=0,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            # decoding uses spatial norm conditioned on the latent sample
            spatial_norm_dim=in_channels,
            pad_mode=pad_mode,
        )

        # up blocks
        self.up_blocks = nn.ModuleList([])

        output_channel = reversed_block_out_channels[0]
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # only the first `temporal_compress_level` blocks upsample time
            compress_time = i < temporal_compress_level

            if up_block_type == "CogVideoXUpBlock3D":
                up_block = CogVideoXUpBlock3D(
                    in_channels=prev_output_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    spatial_norm_dim=in_channels,
                    add_upsample=not is_final_block,
                    compress_time=compress_time,
                    pad_mode=pad_mode,
                )
                prev_output_channel = output_channel
            else:
                raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")

            self.up_blocks.append(up_block)

        self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        r"""The forward method of the `CogVideoXDecoder3D` class.

        The latent `sample` is also threaded through the blocks as `zq` so the
        spatially-conditioned norms can condition on it. Returns the decoded
        tensor together with a nested conv cache (keys ``"conv_in"``,
        ``"mid_block"``, ``"up_block_{i}"``, ``"norm_out"``, ``"conv_out"``)
        for causal streaming across temporal chunks.
        """

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Mid
            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                temb,
                sample,
                conv_cache.get("mid_block"),
            )

            # 2. Up
            for i, up_block in enumerate(self.up_blocks):
                conv_cache_key = f"up_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block),
                    hidden_states,
                    temb,
                    sample,
                    conv_cache.get(conv_cache_key),
                )
        else:
            # 1. Mid
            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
            )

            # 2. Up
            for i, up_block in enumerate(self.up_blocks):
                conv_cache_key = f"up_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = up_block(
                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
                )

        # 3. Post-process
        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
        )
        hidden_states = self.conv_act(hidden_states)
        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))

        return hidden_states, new_conv_cache
|
| 1067 |
+
|
| 1068 |
+
|
| 1069 |
+
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
| 1070 |
+
r"""
|
| 1071 |
+
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
|
| 1072 |
+
[CogVideoX](https://github.com/THUDM/CogVideo).
|
| 1073 |
+
|
| 1074 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
| 1075 |
+
for all models (such as downloading or saving).
|
| 1076 |
+
|
| 1077 |
+
Parameters:
|
| 1078 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
| 1079 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
| 1080 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
| 1081 |
+
Tuple of downsample block types.
|
| 1082 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
| 1083 |
+
Tuple of upsample block types.
|
| 1084 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
| 1085 |
+
Tuple of block output channels.
|
| 1086 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
| 1087 |
+
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
| 1088 |
+
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
|
| 1089 |
+
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
| 1090 |
+
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
| 1091 |
+
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
| 1092 |
+
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
| 1093 |
+
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
| 1094 |
+
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
| 1095 |
+
force_upcast (`bool`, *optional*, default to `True`):
|
| 1096 |
+
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
| 1097 |
+
can be fine-tuned / trained to a lower range without loosing too much precision in which case
|
| 1098 |
+
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
|
| 1099 |
+
"""
|
| 1100 |
+
|
| 1101 |
+
_supports_gradient_checkpointing = True
|
| 1102 |
+
_no_split_modules = ["CogVideoXResnetBlock3D"]
|
| 1103 |
+
|
| 1104 |
+
@register_to_config
|
| 1105 |
+
def __init__(
|
| 1106 |
+
self,
|
| 1107 |
+
in_channels: int = 3,
|
| 1108 |
+
out_channels: int = 3,
|
| 1109 |
+
down_block_types: Tuple[str] = (
|
| 1110 |
+
"CogVideoXDownBlock3D",
|
| 1111 |
+
"CogVideoXDownBlock3D",
|
| 1112 |
+
"CogVideoXDownBlock3D",
|
| 1113 |
+
"CogVideoXDownBlock3D",
|
| 1114 |
+
),
|
| 1115 |
+
up_block_types: Tuple[str] = (
|
| 1116 |
+
"CogVideoXUpBlock3D",
|
| 1117 |
+
"CogVideoXUpBlock3D",
|
| 1118 |
+
"CogVideoXUpBlock3D",
|
| 1119 |
+
"CogVideoXUpBlock3D",
|
| 1120 |
+
),
|
| 1121 |
+
block_out_channels: Tuple[int] = (128, 256, 256, 512),
|
| 1122 |
+
latent_channels: int = 16,
|
| 1123 |
+
layers_per_block: int = 3,
|
| 1124 |
+
act_fn: str = "silu",
|
| 1125 |
+
norm_eps: float = 1e-6,
|
| 1126 |
+
norm_num_groups: int = 32,
|
| 1127 |
+
temporal_compression_ratio: float = 4,
|
| 1128 |
+
sample_height: int = 480,
|
| 1129 |
+
sample_width: int = 720,
|
| 1130 |
+
scaling_factor: float = 1.15258426,
|
| 1131 |
+
shift_factor: Optional[float] = None,
|
| 1132 |
+
latents_mean: Optional[Tuple[float]] = None,
|
| 1133 |
+
latents_std: Optional[Tuple[float]] = None,
|
| 1134 |
+
force_upcast: float = True,
|
| 1135 |
+
use_quant_conv: bool = False,
|
| 1136 |
+
use_post_quant_conv: bool = False,
|
| 1137 |
+
invert_scale_latents: bool = False,
|
| 1138 |
+
):
|
| 1139 |
+
super().__init__()
|
| 1140 |
+
|
| 1141 |
+
self.encoder = CogVideoXEncoder3D(
|
| 1142 |
+
in_channels=in_channels,
|
| 1143 |
+
out_channels=latent_channels,
|
| 1144 |
+
down_block_types=down_block_types,
|
| 1145 |
+
block_out_channels=block_out_channels,
|
| 1146 |
+
layers_per_block=layers_per_block,
|
| 1147 |
+
act_fn=act_fn,
|
| 1148 |
+
norm_eps=norm_eps,
|
| 1149 |
+
norm_num_groups=norm_num_groups,
|
| 1150 |
+
temporal_compression_ratio=temporal_compression_ratio,
|
| 1151 |
+
)
|
| 1152 |
+
self.decoder = CogVideoXDecoder3D(
|
| 1153 |
+
in_channels=latent_channels,
|
| 1154 |
+
out_channels=out_channels,
|
| 1155 |
+
up_block_types=up_block_types,
|
| 1156 |
+
block_out_channels=block_out_channels,
|
| 1157 |
+
layers_per_block=layers_per_block,
|
| 1158 |
+
act_fn=act_fn,
|
| 1159 |
+
norm_eps=norm_eps,
|
| 1160 |
+
norm_num_groups=norm_num_groups,
|
| 1161 |
+
temporal_compression_ratio=temporal_compression_ratio,
|
| 1162 |
+
)
|
| 1163 |
+
self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
|
| 1164 |
+
self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
|
| 1165 |
+
|
| 1166 |
+
self.use_slicing = False
|
| 1167 |
+
self.use_tiling = False
|
| 1168 |
+
self.auto_split_process = False
|
| 1169 |
+
|
| 1170 |
+
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
|
| 1171 |
+
# recommended because the temporal parts of the VAE, here, are tricky to understand.
|
| 1172 |
+
# If you decode X latent frames together, the number of output frames is:
|
| 1173 |
+
# (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
|
| 1174 |
+
#
|
| 1175 |
+
# Example with num_latent_frames_batch_size = 2:
|
| 1176 |
+
# - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
|
| 1177 |
+
# => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
| 1178 |
+
# => 6 * 8 = 48 frames
|
| 1179 |
+
# - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
|
| 1180 |
+
# => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
|
| 1181 |
+
# ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
| 1182 |
+
# => 1 * 9 + 5 * 8 = 49 frames
|
| 1183 |
+
# It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
|
| 1184 |
+
# setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
|
| 1185 |
+
# number of temporal frames.
|
| 1186 |
+
self.num_latent_frames_batch_size = 2
|
| 1187 |
+
self.num_sample_frames_batch_size = 8
|
| 1188 |
+
|
| 1189 |
+
# We make the minimum height and width of sample for tiling half that of the generally supported
|
| 1190 |
+
self.tile_sample_min_height = sample_height // 2
|
| 1191 |
+
self.tile_sample_min_width = sample_width // 2
|
| 1192 |
+
self.tile_latent_min_height = int(
|
| 1193 |
+
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
|
| 1194 |
+
)
|
| 1195 |
+
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
|
| 1196 |
+
|
| 1197 |
+
# These are experimental overlap factors that were chosen based on experimentation and seem to work best for
|
| 1198 |
+
# 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
|
| 1199 |
+
# and so the tiling implementation has only been tested on those specific resolutions.
|
| 1200 |
+
self.tile_overlap_factor_height = 1 / 6
|
| 1201 |
+
self.tile_overlap_factor_width = 1 / 5
|
| 1202 |
+
|
| 1203 |
+
def _set_gradient_checkpointing(self, module, value=False):
    """Toggle gradient checkpointing on encoder/decoder submodules that support it."""
    checkpointable = (CogVideoXEncoder3D, CogVideoXDecoder3D)
    if isinstance(module, checkpointable):
        module.gradient_checkpointing = value
|
| 1206 |
+
|
| 1207 |
+
def enable_tiling(
    self,
    tile_sample_min_height: Optional[int] = None,
    tile_sample_min_width: Optional[int] = None,
    tile_overlap_factor_height: Optional[float] = None,
    tile_overlap_factor_width: Optional[float] = None,
) -> None:
    r"""
    Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
    compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
    processing larger images.

    Args:
        tile_sample_min_height (`int`, *optional*):
            The minimum height required for a sample to be separated into tiles across the height dimension.
        tile_sample_min_width (`int`, *optional*):
            The minimum width required for a sample to be separated into tiles across the width dimension.
        tile_overlap_factor_height (`float`, *optional*):
            The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
            no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
            value might cause more tiles to be processed leading to slow down of the decoding process.
        tile_overlap_factor_width (`float`, *optional*):
            The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
            are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
            value might cause more tiles to be processed leading to slow down of the decoding process.
    """
    self.use_tiling = True
    # Use explicit `is not None` checks instead of `or` so that legitimate
    # falsy values (e.g. an overlap factor of 0.0) are not silently replaced
    # by the previous defaults.
    if tile_sample_min_height is not None:
        self.tile_sample_min_height = tile_sample_min_height
    if tile_sample_min_width is not None:
        self.tile_sample_min_width = tile_sample_min_width
    # Latent tile sizes follow from the spatial downscale factor, which is
    # 2 ** (number of downsampling stages) as configured by block_out_channels.
    self.tile_latent_min_height = int(
        self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
    )
    self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
    if tile_overlap_factor_height is not None:
        self.tile_overlap_factor_height = tile_overlap_factor_height
    if tile_overlap_factor_width is not None:
        self.tile_overlap_factor_width = tile_overlap_factor_width
|
| 1242 |
+
|
| 1243 |
+
def disable_tiling(self) -> None:
    r"""Turn off tiled decoding so the VAE processes each input in a single pass
    (reverses a previous `enable_tiling` call)."""
    self.use_tiling = False
|
| 1249 |
+
|
| 1250 |
+
def enable_slicing(self) -> None:
    r"""Turn on sliced decoding: the batch is decoded one sample at a time,
    trading throughput for a smaller memory footprint."""
    self.use_slicing = True
|
| 1256 |
+
|
| 1257 |
+
def disable_slicing(self) -> None:
    r"""Turn off sliced decoding so the whole batch is decoded in one step
    (reverses a previous `enable_slicing` call)."""
    self.use_slicing = False
|
| 1263 |
+
|
| 1264 |
+
def _set_first_frame(self):
|
| 1265 |
+
for name, module in self.named_modules():
|
| 1266 |
+
if isinstance(module, CogVideoXUpsample3D):
|
| 1267 |
+
module.auto_split_process = False
|
| 1268 |
+
module.first_frame_flag = True
|
| 1269 |
+
|
| 1270 |
+
def _set_rest_frame(self):
|
| 1271 |
+
for name, module in self.named_modules():
|
| 1272 |
+
if isinstance(module, CogVideoXUpsample3D):
|
| 1273 |
+
module.auto_split_process = False
|
| 1274 |
+
module.first_frame_flag = False
|
| 1275 |
+
|
| 1276 |
+
def enable_auto_split_process(self) -> None:
    """Enable automatic temporal splitting on this VAE and on all
    CogVideoXUpsample3D submodules."""
    self.auto_split_process = True
    for _, submodule in self.named_modules():
        if isinstance(submodule, CogVideoXUpsample3D):
            submodule.auto_split_process = True
|
| 1281 |
+
|
| 1282 |
+
def disable_auto_split_process(self) -> None:
    """Disable automatic temporal splitting; decode then uses the manual
    first-frame / rest-frames path."""
    self.auto_split_process = False
|
| 1284 |
+
|
| 1285 |
+
def _encode(self, x: torch.Tensor) -> torch.Tensor:
    # Encode a single batch of videos into latents, processing frames in
    # temporal chunks; the causal-conv cache is threaded through the chunks so
    # the result matches a single full-length pass while bounding memory.
    batch_size, num_channels, num_frames, height, width = x.shape

    # Fall back to spatial tiling when the input exceeds the supported tile size.
    if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
        return self.tiled_encode(x)

    frame_batch_size = self.num_sample_frames_batch_size
    # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
    # As the extra single frame is handled inside the loop, it is not required to round up here.
    num_batches = max(num_frames // frame_batch_size, 1)
    conv_cache = None
    enc = []

    for i in range(num_batches):
        remaining_frames = num_frames % frame_batch_size
        # The first chunk absorbs the `remaining_frames` extra frames; later
        # chunks are shifted by the same amount so the slices stay contiguous
        # and cover all `num_frames` frames exactly once.
        start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
        end_frame = frame_batch_size * (i + 1) + remaining_frames
        x_intermediate = x[:, :, start_frame:end_frame]
        # conv_cache carries causal-conv state from the previous chunk.
        x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
        if self.quant_conv is not None:
            x_intermediate = self.quant_conv(x_intermediate)
        enc.append(x_intermediate)

    # Concatenate chunks back along the temporal axis (dim=2).
    enc = torch.cat(enc, dim=2)
    return enc
|
| 1310 |
+
|
| 1311 |
+
@apply_forward_hook
def encode(
    self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
    """
    Encode a batch of images into latents.

    Args:
        x (`torch.Tensor`): Input batch of images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

    Returns:
        The latent representations of the encoded videos. If `return_dict` is True, a
        [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
    """
    # With slicing enabled, each batch element is encoded independently to
    # cap peak memory; results are concatenated back along the batch axis.
    if self.use_slicing and x.shape[0] > 1:
        h = torch.cat([self._encode(piece) for piece in x.split(1)])
    else:
        h = self._encode(x)

    posterior = DiagonalGaussianDistribution(h)
    return AutoencoderKLOutput(latent_dist=posterior) if return_dict else (posterior,)
|
| 1338 |
+
|
| 1339 |
+
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
    # Decode latents back to pixel space. Two temporal strategies exist:
    #   * auto_split_process: uniform chunks of `num_latent_frames_batch_size`
    #     latent frames, with the causal-conv cache threaded through chunks.
    #   * manual mode: the first latent frame is decoded separately (upsamplers
    #     flagged via _set_first_frame), then the rest in fixed-size chunks.
    batch_size, num_channels, num_frames, height, width = z.shape

    # Fall back to spatial tiling when the latent exceeds the tile size.
    if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
        return self.tiled_decode(z, return_dict=return_dict)

    if self.auto_split_process:
        frame_batch_size = self.num_latent_frames_batch_size
        num_batches = max(num_frames // frame_batch_size, 1)
        conv_cache = None
        dec = []

        for i in range(num_batches):
            remaining_frames = num_frames % frame_batch_size
            # First chunk absorbs the leftover frames; later chunks shift by
            # the same amount so slices stay contiguous (mirrors _encode).
            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
            end_frame = frame_batch_size * (i + 1) + remaining_frames
            z_intermediate = z[:, :, start_frame:end_frame]
            if self.post_quant_conv is not None:
                z_intermediate = self.post_quant_conv(z_intermediate)
            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
            dec.append(z_intermediate)
    else:
        conv_cache = None
        start_frame = 0
        end_frame = 1
        dec = []

        # Decode latent frame 0 with the upsamplers in first-frame mode.
        self._set_first_frame()
        z_intermediate = z[:, :, start_frame:end_frame]
        if self.post_quant_conv is not None:
            z_intermediate = self.post_quant_conv(z_intermediate)
        z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
        dec.append(z_intermediate)

        # Remaining frames are decoded in fixed-size chunks in rest-frame mode.
        self._set_rest_frame()
        start_frame = end_frame
        end_frame += self.num_latent_frames_batch_size

        while start_frame < num_frames:
            z_intermediate = z[:, :, start_frame:end_frame]
            if self.post_quant_conv is not None:
                z_intermediate = self.post_quant_conv(z_intermediate)
            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
            dec.append(z_intermediate)
            start_frame = end_frame
            end_frame += self.num_latent_frames_batch_size

    # Reassemble decoded chunks along the temporal axis.
    dec = torch.cat(dec, dim=2)

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
|
| 1392 |
+
|
| 1393 |
+
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
    """
    Decode a batch of images.

    Args:
        z (`torch.Tensor`): Input batch of latent vectors.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

    Returns:
        [`~models.vae.DecoderOutput`] or `tuple`:
            If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
            returned.
    """
    # With slicing enabled, each batch element is decoded independently to
    # cap peak memory; results are concatenated back along the batch axis.
    if self.use_slicing and z.shape[0] > 1:
        decoded = torch.cat([self._decode(piece).sample for piece in z.split(1)])
    else:
        decoded = self._decode(z).sample

    return DecoderOutput(sample=decoded) if return_dict else (decoded,)
|
| 1417 |
+
|
| 1418 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the bottom rows of `a` into the top rows of `b` along the
    height axis (dim 3). Mutates and returns `b`."""
    extent = min(a.shape[3], b.shape[3], blend_extent)
    for row in range(extent):
        weight = row / extent
        b[:, :, :, row, :] = a[:, :, :, row - extent, :] * (1 - weight) + b[:, :, :, row, :] * weight
    return b
|
| 1425 |
+
|
| 1426 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the rightmost columns of `a` into the leftmost columns of `b`
    along the width axis (dim 4). Mutates and returns `b`."""
    extent = min(a.shape[4], b.shape[4], blend_extent)
    for col in range(extent):
        weight = col / extent
        b[:, :, :, :, col] = a[:, :, :, :, col - extent] * (1 - weight) + b[:, :, :, :, col] * weight
    return b
|
| 1433 |
+
|
| 1434 |
+
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
    r"""Encode a batch of images using a tiled encoder.

    When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
    steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
    different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
    tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
    output, but they should be much less noticeable.

    Args:
        x (`torch.Tensor`): Input batch of videos.

    Returns:
        `torch.Tensor`:
            The latent representation of the encoded videos.
    """
    # For a rough memory estimate, take a look at the `tiled_decode` method.
    batch_size, num_channels, num_frames, height, width = x.shape

    # Tile stride in pixel space, blend/crop extents in latent space (the
    # encoder downsamples, so the two live at different resolutions).
    overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
    overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
    blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
    blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
    row_limit_height = self.tile_latent_min_height - blend_extent_height
    row_limit_width = self.tile_latent_min_width - blend_extent_width
    frame_batch_size = self.num_sample_frames_batch_size

    # Split x into overlapping tiles and encode them separately.
    # The tiles have an overlap to avoid seams between tiles.
    rows = []
    for i in range(0, height, overlap_height):
        row = []
        for j in range(0, width, overlap_width):
            # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
            # As the extra single frame is handled inside the loop, it is not required to round up here.
            num_batches = max(num_frames // frame_batch_size, 1)
            # The causal-conv cache is per-tile: it carries state between
            # temporal chunks but is reset for each spatial tile.
            conv_cache = None
            time = []

            for k in range(num_batches):
                remaining_frames = num_frames % frame_batch_size
                # First temporal chunk absorbs the leftover frames; later
                # chunks shift by the same amount (mirrors `_encode`).
                start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                end_frame = frame_batch_size * (k + 1) + remaining_frames
                tile = x[
                    :,
                    :,
                    start_frame:end_frame,
                    i : i + self.tile_sample_min_height,
                    j : j + self.tile_sample_min_width,
                ]
                tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
                if self.quant_conv is not None:
                    tile = self.quant_conv(tile)
                time.append(tile)

            row.append(torch.cat(time, dim=2))
        rows.append(row)

    # Blend neighbouring tiles over the overlap region, crop to the
    # non-overlapping core, then stitch along width (dim 4) and height (dim 3).
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # blend the above tile and the left tile
            # to the current tile and add the current tile to the result row
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent_width)
            result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
        result_rows.append(torch.cat(result_row, dim=4))

    enc = torch.cat(result_rows, dim=3)
    return enc
|
| 1507 |
+
|
| 1508 |
+
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
    r"""
    Decode a batch of images using a tiled decoder.

    Args:
        z (`torch.Tensor`): Input batch of latent vectors.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

    Returns:
        [`~models.vae.DecoderOutput`] or `tuple`:
            If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
            returned.
    """
    # Rough memory assessment:
    #   - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
    #   - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
    #   - Assume fp16 (2 bytes per value).
    # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
    #
    # Memory assessment when using tiling:
    #   - Assume everything as above but now HxW is 240x360 by tiling in half
    # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB

    batch_size, num_channels, num_frames, height, width = z.shape

    # Tile stride in latent space, blend/crop extents in pixel space (the
    # decoder upsamples, so the two live at different resolutions).
    overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
    overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
    blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
    blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
    row_limit_height = self.tile_sample_min_height - blend_extent_height
    row_limit_width = self.tile_sample_min_width - blend_extent_width
    frame_batch_size = self.num_latent_frames_batch_size

    # Split z into overlapping tiles and decode them separately.
    # The tiles have an overlap to avoid seams between tiles.
    rows = []
    for i in range(0, height, overlap_height):
        row = []
        for j in range(0, width, overlap_width):
            if self.auto_split_process:
                # Uniform temporal chunks; conv cache is per spatial tile and
                # threaded across the chunks (mirrors `_decode`).
                num_batches = max(num_frames // frame_batch_size, 1)
                conv_cache = None
                time = []

                for k in range(num_batches):
                    remaining_frames = num_frames % frame_batch_size
                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                    end_frame = frame_batch_size * (k + 1) + remaining_frames
                    tile = z[
                        :,
                        :,
                        start_frame:end_frame,
                        i : i + self.tile_latent_min_height,
                        j : j + self.tile_latent_min_width,
                    ]
                    if self.post_quant_conv is not None:
                        tile = self.post_quant_conv(tile)
                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                    time.append(tile)

                row.append(torch.cat(time, dim=2))
            else:
                # Manual mode: decode latent frame 0 with the upsamplers in
                # first-frame mode, then the rest in fixed-size chunks.
                conv_cache = None
                start_frame = 0
                end_frame = 1
                dec = []

                tile = z[
                    :,
                    :,
                    start_frame:end_frame,
                    i : i + self.tile_latent_min_height,
                    j : j + self.tile_latent_min_width,
                ]

                self._set_first_frame()
                if self.post_quant_conv is not None:
                    tile = self.post_quant_conv(tile)
                tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                dec.append(tile)

                self._set_rest_frame()
                start_frame = end_frame
                end_frame += self.num_latent_frames_batch_size

                while start_frame < num_frames:
                    tile = z[
                        :,
                        :,
                        start_frame:end_frame,
                        i : i + self.tile_latent_min_height,
                        j : j + self.tile_latent_min_width,
                    ]
                    if self.post_quant_conv is not None:
                        tile = self.post_quant_conv(tile)
                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                    dec.append(tile)
                    start_frame = end_frame
                    end_frame += self.num_latent_frames_batch_size

                row.append(torch.cat(dec, dim=2))
        rows.append(row)

    # Blend neighbouring tiles over the overlap region, crop to the
    # non-overlapping core, then stitch along width (dim 4) and height (dim 3).
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # blend the above tile and the left tile
            # to the current tile and add the current tile to the result row
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent_width)
            result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
        result_rows.append(torch.cat(result_row, dim=4))

    dec = torch.cat(result_rows, dim=3)

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
|
| 1631 |
+
|
| 1632 |
+
def forward(
    self,
    sample: torch.Tensor,
    sample_posterior: bool = False,
    return_dict: bool = True,
    generator: Optional[torch.Generator] = None,
) -> Union[torch.Tensor, torch.Tensor]:
    """Round-trip `sample` through the VAE: encode to the latent distribution,
    take a sample (if `sample_posterior`) or the mode, and decode it back."""
    posterior = self.encode(sample).latent_dist
    z = posterior.sample(generator=generator) if sample_posterior else posterior.mode()
    dec = self.decode(z)
    return dec if return_dict else (dec,)
|
| 1649 |
+
|
| 1650 |
+
@classmethod
def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
    """Instantiate the VAE from a local directory.

    Reads `config.json` from `pretrained_model_path` (optionally joined with
    `subfolder`), builds the model via `from_config`, then loads weights from
    the safetensors checkpoint if present, otherwise from the `.bin` file.

    Args:
        pretrained_model_path: Directory containing `config.json` and weights.
        subfolder: Optional subdirectory inside `pretrained_model_path`.
        **vae_additional_kwargs: Overrides forwarded to `from_config`.

    Returns:
        The model with weights loaded (non-strict: mismatched keys are printed).

    Raises:
        RuntimeError: If the config file or weight file does not exist.
    """
    if subfolder is not None:
        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)

    config_file = os.path.join(pretrained_model_path, 'config.json')
    if not os.path.isfile(config_file):
        raise RuntimeError(f"{config_file} does not exist")
    with open(config_file, "r") as f:
        config = json.load(f)

    model = cls.from_config(config, **vae_additional_kwargs)
    from diffusers.utils import WEIGHTS_NAME
    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
    model_file_safetensors = model_file.replace(".bin", ".safetensors")
    if os.path.exists(model_file_safetensors):
        # Prefer safetensors: zero-copy and safe to load untrusted files.
        from safetensors.torch import load_file
        state_dict = load_file(model_file_safetensors)
    else:
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")
        # NOTE(review): torch.load unpickles arbitrary code — only load
        # checkpoints from trusted sources.
        state_dict = torch.load(model_file, map_location="cpu")
    # strict=False tolerates architecture deltas; report what didn't match.
    m, u = model.load_state_dict(state_dict, strict=False)
    print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
    print(m, u)
    return model
|
videox_fun/pipeline/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .pipeline_cogvideox_fun import CogVideoXFunPipeline
|
| 2 |
+
from .pipeline_cogvideox_fun_inpaint import CogVideoXFunInpaintPipeline
|
videox_fun/pipeline/pipeline_cogvideox_fun.py
ADDED
|
@@ -0,0 +1,862 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 24 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 26 |
+
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
| 27 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 28 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 29 |
+
from diffusers.video_processor import VideoProcessor
|
| 30 |
+
|
| 31 |
+
from ..models import (AutoencoderKLCogVideoX,
|
| 32 |
+
CogVideoXTransformer3DModel, T5EncoderModel,
|
| 33 |
+
T5Tokenizer)
|
| 34 |
+
|
| 35 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
EXAMPLE_DOC_STRING = """
|
| 39 |
+
Examples:
|
| 40 |
+
```python
|
| 41 |
+
pass
|
| 42 |
+
```
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed
|
| 47 |
+
def get_3d_rotary_pos_embed(
    embed_dim,
    crops_coords,
    grid_size,
    temporal_size,
    theta: int = 10000,
    use_real: bool = True,
    grid_type: str = "linspace",
    max_size: Optional[Tuple[int, int]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

    Args:
        embed_dim: (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            The grid size of the spatial positional embedding (height, width).
        temporal_size (`int`):
            The size of the temporal dimension.
        theta (`float`):
            Scaling factor for frequency computation.
        use_real (`bool`):
            Must be `True`; returns separate (cos, sin) tensors.
        grid_type (`str`):
            Whether to use "linspace" or "slice" to compute grids.
        max_size (`Tuple[int, int]`, *optional*):
            Maximum (height, width) for the "slice" grid type.

    Returns:
        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
    """
    if use_real is not True:
        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")

    if grid_type == "linspace":
        start, stop = crops_coords
        grid_size_h, grid_size_w = grid_size
        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
        # (Removed a dead `np.arange` assignment that was immediately
        # overwritten by this linspace.)
        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
    elif grid_type == "slice":
        max_h, max_w = max_size
        grid_size_h, grid_size_w = grid_size
        grid_h = np.arange(max_h, dtype=np.float32)
        grid_w = np.arange(max_w, dtype=np.float32)
        grid_t = np.arange(temporal_size, dtype=np.float32)
    else:
        raise ValueError("Invalid value passed for `grid_type`.")

    # Compute dimensions for each axis (t gets 1/4 of the head dim, h and w 3/8 each).
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Temporal frequencies
    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
    # Spatial frequencies for height and width
    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)

    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
        freqs_t = freqs_t[:, None, None, :].expand(
            -1, grid_size_h, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_t
        freqs_h = freqs_h[None, :, None, :].expand(
            temporal_size, -1, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_h
        freqs_w = freqs_w[None, None, :, :].expand(
            temporal_size, grid_size_h, -1, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_w

        freqs = torch.cat(
            [freqs_t, freqs_h, freqs_w], dim=-1
        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
        freqs = freqs.view(
            temporal_size * grid_size_h * grid_size_w, -1
        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
        return freqs

    t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w

    if grid_type == "slice":
        # Full-range tables were built above; slice them down to the requested sizes.
        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]

    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
| 142 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Compute the resize-then-center-crop region mapping `src` onto a target grid.

    Args:
        src: `(height, width)` of the source grid.
        tgt_width: Target grid width.
        tgt_height: Target grid height.

    Returns:
        A pair `((top, left), (bottom, right))` describing the crop window, in
        target-grid coordinates, that the resized source occupies.
    """
    src_h, src_w = src

    # Scale the source so it fits inside the target while keeping aspect ratio:
    # match the constraining dimension, derive the other one proportionally.
    if src_h / src_w > tgt_height / tgt_width:
        # Source is relatively taller than the target: height is the constraint.
        new_h = tgt_height
        new_w = int(round(tgt_height / src_h * src_w))
    else:
        # Source is relatively wider (or equal): width is the constraint.
        new_w = tgt_width
        new_h = int(round(tgt_width / src_w * src_h))

    # Center the resized region inside the target grid.
    top = int(round((tgt_height - new_h) / 2.0))
    left = int(round((tgt_width - new_w) / 2.0))

    return (top, left), (top + new_h, left + new_w)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 161 |
+
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # Custom timesteps and custom sigmas are mutually exclusive overrides.
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        # Only schedulers whose `set_timesteps` exposes a `timesteps` kwarg can take a custom schedule.
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        # Likewise, custom sigmas require a `sigmas` kwarg on `set_timesteps`.
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        # Default path: let the scheduler derive its own schedule from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps

    return timesteps, num_inference_steps
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
@dataclass
class CogVideoXFunPipelineOutput(BaseOutput):
    r"""
    Output class for CogVideo pipelines.

    Args:
        videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Generated video(s); see the `videos` entry in the class docstring above.
    videos: torch.Tensor
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class CogVideoXFunPipeline(DiffusionPipeline):
|
| 236 |
+
r"""
|
| 237 |
+
Pipeline for text-to-video generation using CogVideoX_Fun.
|
| 238 |
+
|
| 239 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 240 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
vae ([`AutoencoderKL`]):
|
| 244 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 245 |
+
text_encoder ([`T5EncoderModel`]):
|
| 246 |
+
Frozen text-encoder. CogVideoX uses
|
| 247 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
| 248 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
| 249 |
+
tokenizer (`T5Tokenizer`):
|
| 250 |
+
Tokenizer of class
|
| 251 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 252 |
+
transformer ([`CogVideoXTransformer3DModel`]):
|
| 253 |
+
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
| 254 |
+
scheduler ([`SchedulerMixin`]):
|
| 255 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
| 256 |
+
"""
|
| 257 |
+
|
| 258 |
+
_optional_components = []
|
| 259 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 260 |
+
|
| 261 |
+
_callback_tensor_inputs = [
|
| 262 |
+
"latents",
|
| 263 |
+
"prompt_embeds",
|
| 264 |
+
"negative_prompt_embeds",
|
| 265 |
+
]
|
| 266 |
+
|
| 267 |
+
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    vae: AutoencoderKLCogVideoX,
    transformer: CogVideoXTransformer3DModel,
    scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
):
    """Register all pipeline modules and derive the VAE compression factors."""
    super().__init__()

    self.register_modules(
        tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
    )

    # Spatial / temporal compression ratios come from the registered VAE; fall
    # back to the defaults (8x spatial, 4x temporal) when no VAE is available.
    if hasattr(self, "vae") and self.vae is not None:
        self.vae_scale_factor_spatial = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
    else:
        self.vae_scale_factor_spatial = 8
        self.vae_scale_factor_temporal = 4

    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 288 |
+
|
| 289 |
+
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Tokenize and encode `prompt` with the T5 text encoder.

    Returns embeddings of shape `(len(prompt) * num_videos_per_prompt, seq_len, hidden)`.
    Logs a warning when the input is truncated to `max_sequence_length` tokens.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = len(prompt)

    text_input_ids = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    ).input_ids
    # Re-tokenize without truncation so we can detect (and report) what was cut off.
    untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    embeds = self.text_encoder(text_input_ids.to(device))[0]
    embeds = embeds.to(dtype=dtype, device=device)

    # Duplicate embeddings per requested video using an mps-friendly repeat+view.
    _, seq_len, _ = embeds.shape
    embeds = embeds.repeat(1, num_videos_per_prompt, 1)
    embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return embeds
|
| 330 |
+
|
| 331 |
+
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            Whether to use classifier free guidance or not.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        device: (`torch.device`, *optional*):
            torch device
        dtype: (`torch.dtype`, *optional*):
            torch dtype
    """
    device = device or self._execution_device

    if isinstance(prompt, str):
        prompt = [prompt]
    # Batch size follows the prompt list when given, otherwise the precomputed embeddings.
    batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

    if prompt_embeds is None:
        prompt_embeds = self._get_t5_prompt_embeds(
            prompt=prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    # Negative embeddings are only needed for classifier-free guidance.
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        if isinstance(negative_prompt, str):
            negative_prompt = batch_size * [negative_prompt]

        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        negative_prompt_embeds = self._get_t5_prompt_embeds(
            prompt=negative_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    return prompt_embeds, negative_prompt_embeds
|
| 411 |
+
|
| 412 |
+
def prepare_latents(
    self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
    """Create (or move) the initial latent tensor and scale it by the scheduler's noise sigma.

    Latent shape is `(batch, latent_frames, channels, H/vae, W/vae)` where
    `latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1`.
    """
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
    latent_height = height // self.vae_scale_factor_spatial
    latent_width = width // self.vae_scale_factor_spatial
    shape = (batch_size, latent_frames, num_channels_latents, latent_height, latent_width)

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # Scale the initial noise by the standard deviation required by the scheduler.
    return latents * self.scheduler.init_noise_sigma
|
| 437 |
+
|
| 438 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
    """Decode video latents to pixel space.

    Takes latents shaped `(batch, num_frames, channels, height, width)` and
    returns a float32 NumPy array of frames with values in `[0, 1]`.
    """
    # Reorder to [batch, channels, num_frames, height, width] for the VAE decoder.
    video_latents = latents.permute(0, 2, 1, 3, 4)
    # Undo the scaling applied when the video was encoded.
    video_latents = 1 / self.vae.config.scaling_factor * video_latents

    decoded = self.vae.decode(video_latents).sample
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # Cast to float32: negligible overhead and keeps bfloat16 outputs usable downstream.
    return decoded.cpu().float().numpy()
|
| 447 |
+
|
| 448 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 449 |
+
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the kwargs for `scheduler.step`, keeping only those the scheduler accepts.

    `eta` (η) is only used with the DDIMScheduler and corresponds to η in the
    DDIM paper (https://arxiv.org/abs/2010.02502); it should be in [0, 1] and
    is ignored by other schedulers. `generator` is likewise passed through
    only when the scheduler's `step` signature accepts it.
    """
    step_params = set(inspect.signature(self.scheduler.step).parameters.keys())

    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
|
| 465 |
+
|
| 466 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 467 |
+
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt,
    callback_on_step_end_tensor_inputs,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate `__call__` arguments before running the pipeline.

    Checks run in a fixed order; the first failing check raises, so do not
    reorder them. Raises `ValueError` (or `TypeError`-free `ValueError`s as
    below) when:
      - `height`/`width` are not divisible by 8 (VAE spatial stride),
      - an unknown callback tensor name is requested,
      - both or neither of `prompt`/`prompt_embeds` are supplied,
      - `prompt` is neither `str` nor `list`,
      - a raw prompt is mixed with precomputed negative embeddings,
      - positive and negative embeddings have mismatched shapes.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # Every requested callback tensor must be one the pipeline actually tracks.
    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # Mixing a raw prompt with precomputed negative embeddings is ambiguous.
    if prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        # Positive/negative embeddings are concatenated for CFG, so shapes must agree.
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
|
| 517 |
+
|
| 518 |
+
def fuse_qkv_projections(self) -> None:
    r"""Enable fused query/key/value projections in the transformer."""
    # Record the fused state first so unfuse_qkv_projections() knows there is
    # something to undo, then fuse the projections on the transformer itself.
    self.fusing_transformer = True
    self.transformer.fuse_qkv_projections()
|
| 522 |
+
|
| 523 |
+
def unfuse_qkv_projections(self) -> None:
    r"""Disable QKV projection fusion if enabled."""
    # `fusing_transformer` is only created by fuse_qkv_projections(); use a
    # getattr default so calling unfuse first logs the warning instead of
    # raising AttributeError (the attribute is never set in __init__).
    if not getattr(self, "fusing_transformer", False):
        logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
    else:
        self.transformer.unfuse_qkv_projections()
        self.fusing_transformer = False
|
| 530 |
+
|
| 531 |
+
def _prepare_rotary_positional_embeddings(
    self,
    height: int,
    width: int,
    num_frames: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build 3D rotary positional embeddings (cos, sin) for the transformer.

    `height`/`width` are pixel sizes; the latent grid is derived from the VAE
    spatial scale factor and the transformer patch size. The CogVideoX 1.0
    path (no temporal patching) uses a resize-crop grid, while the 1.5 path
    (`patch_size_t` set) uses sliced grids with ceil-divided frame count.
    """
    patch = self.transformer.config.patch_size
    patch_t = self.transformer.config.patch_size_t

    grid_height = height // (self.vae_scale_factor_spatial * patch)
    grid_width = width // (self.vae_scale_factor_spatial * patch)
    base_size_width = self.transformer.config.sample_width // patch
    base_size_height = self.transformer.config.sample_height // patch

    if patch_t is None:
        # CogVideoX 1.0: map the requested grid onto the model's base grid.
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=self.transformer.config.attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
        )
    else:
        # CogVideoX 1.5: frames are patched temporally — ceil-divide by patch_size_t.
        base_num_frames = (num_frames + patch_t - 1) // patch_t
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=self.transformer.config.attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(base_size_height, base_size_width),
        )

    return freqs_cos.to(device=device), freqs_sin.to(device=device)
|
| 574 |
+
|
| 575 |
+
@property
def guidance_scale(self):
    # Read-only view of the guidance scale; `_guidance_scale` is presumably
    # assigned in `__call__` — the setter is outside this view, confirm there.
    return self._guidance_scale
|
| 578 |
+
|
| 579 |
+
@property
def num_timesteps(self):
    # Read-only view of the number of denoising steps; `_num_timesteps` is
    # presumably set during `__call__` — setter not visible here, confirm.
    return self._num_timesteps
|
| 582 |
+
|
| 583 |
+
@property
def attention_kwargs(self):
    # Read-only view of the attention kwargs; `_attention_kwargs` is presumably
    # stored from the `attention_kwargs` argument of `__call__` — confirm there.
    return self._attention_kwargs
|
| 586 |
+
|
| 587 |
+
@property
def interrupt(self):
    # Read-only interruption flag; `_interrupt` is presumably toggled during the
    # denoising loop in `__call__` — setter not visible in this view, confirm.
    return self._interrupt
|
| 590 |
+
|
| 591 |
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    height: int = 480,
    width: int = 720,
    num_frames: int = 49,
    num_inference_steps: int = 50,
    timesteps: Optional[List[int]] = None,
    guidance_scale: float = 6,
    use_dynamic_cfg: bool = False,
    num_videos_per_prompt: int = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: str = "numpy",
    return_dict: bool = False,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 226,
) -> Union[CogVideoXFunPipelineOutput, Tuple]:
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image. This is set to 1024 by default for the best results.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image. This is set to 1024 by default for the best results.
        num_frames (`int`, defaults to `48`):
            Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
            contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
            num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
            needs to be satisfied is that of divisibility mentioned above.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 6):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will ge generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"numpy"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
            of a plain tuple.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int`, defaults to `226`):
            Maximum sequence length in encoded prompt. Must be consistent with
            `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.

    Examples:

    Returns:
        [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
        [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images.
    """

    # A callback object carries its own declared tensor inputs; prefer those.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # Fall back to the transformer's trained sample size when the caller passes falsy values.
    height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
    width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
    num_frames = num_frames or self.transformer.config.sample_frames

    # NOTE(review): `num_videos_per_prompt` is unconditionally overridden to 1 here,
    # so the corresponding argument currently has no effect.
    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds,
        negative_prompt_embeds,
    )
    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._interrupt = False

    # 2. Default call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        negative_prompt,
        do_classifier_free_guidance,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
    )
    if do_classifier_free_guidance:
        # Stack [uncond, cond] so both branches run in a single transformer forward pass.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    self._num_timesteps = len(timesteps)

    # 5. Prepare latents
    latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

    # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
    patch_size_t = self.transformer.config.patch_size_t
    additional_frames = 0
    if num_frames != 1 and patch_size_t is not None and latent_frames % patch_size_t != 0:
        additional_frames = patch_size_t - latent_frames % patch_size_t
        num_frames += additional_frames * self.vae_scale_factor_temporal

    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        latent_channels,
        num_frames,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Create rotary embeds if required
    image_rotary_emb = (
        self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
        if self.transformer.config.use_rotary_positional_embeddings
        else None
    )

    # 8. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        # for DPM-solver++
        old_pred_original_sample = None
        for i, t in enumerate(timesteps):
            # `interrupt` lets an external caller abort generation; remaining steps are skipped.
            if self.interrupt:
                continue

            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            # predict noise model_output
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                image_rotary_emb=image_rotary_emb,
                return_dict=False,
            )[0]
            noise_pred = noise_pred.float()

            # perform guidance
            if use_dynamic_cfg:
                # Cosine ramp of the guidance weight over the schedule (stronger near the end).
                self._guidance_scale = 1 + guidance_scale * (
                    (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                )
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
            else:
                # The DPM scheduler is stateful across steps via `old_pred_original_sample`
                # and also needs the previous timestep.
                latents, old_pred_original_sample = self.scheduler.step(
                    noise_pred,
                    old_pred_original_sample,
                    t,
                    timesteps[i - 1] if i > 0 else None,
                    latents,
                    **extra_step_kwargs,
                    return_dict=False,
                )
            latents = latents.to(prompt_embeds.dtype)

            # call the callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    if output_type == "numpy":
        video = self.decode_latents(latents)
    elif not output_type == "latent":
        video = self.decode_latents(latents)
        video = self.video_processor.postprocess_video(video=video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    self.maybe_free_model_hooks()

    # NOTE(review): this branch assumes `video` is a NumPy array, which holds for the
    # default `output_type="numpy"` path; for "latent" (already a tensor) or
    # postprocessed outputs this conversion looks wrong — verify against callers.
    if not return_dict:
        video = torch.from_numpy(video)

    # NOTE(review): a CogVideoXFunPipelineOutput is returned regardless of `return_dict`.
    return CogVideoXFunPipelineOutput(videos=video)
|
videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py
ADDED
|
@@ -0,0 +1,1244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 25 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 26 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 27 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 28 |
+
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
| 29 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 30 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 31 |
+
from diffusers.video_processor import VideoProcessor
|
| 32 |
+
from einops import rearrange
|
| 33 |
+
|
| 34 |
+
from ..models import (AutoencoderKLCogVideoX,
|
| 35 |
+
CogVideoXTransformer3DModel, T5EncoderModel,
|
| 36 |
+
T5Tokenizer)
|
| 37 |
+
|
| 38 |
+
# Module-level logger obtained through diffusers' logging facade.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Placeholder usage example; presumably injected into `__call__`'s docstring via
# `@replace_example_docstring` (imported above) — the snippet is intentionally a no-op.
EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        pass
        ```
"""
|
| 47 |
+
|
| 48 |
+
# Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed
def get_3d_rotary_pos_embed(
    embed_dim,
    crops_coords,
    grid_size,
    temporal_size,
    theta: int = 10000,
    use_real: bool = True,
    grid_type: str = "linspace",
    max_size: Optional[Tuple[int, int]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

    Args:
        embed_dim: (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop. Only used when `grid_type` is `"linspace"`.
        grid_size (`Tuple[int]`):
            The grid size of the spatial positional embedding (height, width).
        temporal_size (`int`):
            The size of the temporal dimension.
        theta (`float`):
            Scaling factor for frequency computation.
        use_real (`bool`):
            Must be `True`; only the split (cos, sin) representation is supported.
        grid_type (`str`):
            Whether to use "linspace" or "slice" to compute grids.
        max_size (`Tuple[int, int]`, *optional*):
            Maximum (height, width) of the spatial grid; required when `grid_type` is `"slice"`.

    Returns:
        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.

    Raises:
        ValueError: if `use_real` is not True or `grid_type` is unrecognized.
    """
    if use_real is not True:
        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")

    # NOTE(review): `theta` is accepted but never forwarded to `get_1d_rotary_pos_embed`,
    # so non-default values are silently ignored — confirm against upstream diffusers.
    if grid_type == "linspace":
        start, stop = crops_coords
        grid_size_h, grid_size_w = grid_size
        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
        # Fix: removed a dead `grid_t = np.arange(temporal_size, ...)` store that the original
        # immediately overwrote with this equivalent linspace (endpoint=False == arange here).
        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
    elif grid_type == "slice":
        max_h, max_w = max_size
        grid_size_h, grid_size_w = grid_size
        grid_h = np.arange(max_h, dtype=np.float32)
        grid_w = np.arange(max_w, dtype=np.float32)
        grid_t = np.arange(temporal_size, dtype=np.float32)
    else:
        raise ValueError("Invalid value passed for `grid_type`.")

    # Compute dimensions for each axis: 1/4 of the head dim is temporal, 3/8 each for H and W.
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Temporal frequencies
    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
    # Spatial frequencies for height and width
    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)

    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
        freqs_t = freqs_t[:, None, None, :].expand(
            -1, grid_size_h, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_t
        freqs_h = freqs_h[None, :, None, :].expand(
            temporal_size, -1, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_h
        freqs_w = freqs_w[None, None, :, :].expand(
            temporal_size, grid_size_h, -1, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_w

        freqs = torch.cat(
            [freqs_t, freqs_h, freqs_w], dim=-1
        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
        freqs = freqs.view(
            temporal_size * grid_size_h * grid_size_w, -1
        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
        return freqs

    t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w

    if grid_type == "slice":
        # The "slice" grids were built over the full `max_size`; trim back to the requested sizes.
        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]

    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Aspect-preserving fit of a `src` (height, width) grid into a target box.

    Scales `src` so that it fits entirely inside (`tgt_height`, `tgt_width`)
    while keeping its aspect ratio, then centers it. Returns the region as
    ((top, left), (bottom, right)) coordinates within the target box.
    """
    src_h, src_w = src

    # Compare aspect ratios to decide which target edge constrains the fit.
    if src_h / src_w > tgt_height / tgt_width:
        # Source is relatively taller: height is the limiting dimension.
        fit_h = tgt_height
        fit_w = int(round(tgt_height / src_h * src_w))
    else:
        # Source is relatively wider (or equal): width is the limiting dimension.
        fit_w = tgt_width
        fit_h = int(round(tgt_width / src_w * src_h))

    # Center the fitted region inside the target box.
    top = int(round((tgt_height - fit_h) / 2.0))
    left = int(round((tgt_width - fit_w) / 2.0))

    return (top, left), (top + fit_h, left + fit_w)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler`'s timestep schedule and return it.

    Exactly one schedule source is honored: custom `timesteps`, custom `sigmas`, or the
    scheduler's own spacing driven by `num_inference_steps`. Custom schedules require the
    scheduler's `set_timesteps` to accept the corresponding keyword argument. Any extra
    `kwargs` are forwarded to `set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`): The scheduler to get timesteps from.
        num_inference_steps (`int`): Number of diffusion steps; must be `None` if
            `timesteps` is used.
        device (`str` or `torch.device`, *optional*): Device the timesteps are moved to.
        timesteps (`List[int]`, *optional*): Custom timestep schedule; mutually exclusive
            with `sigmas`.
        sigmas (`List[float]`, *optional*): Custom sigma schedule; mutually exclusive with
            `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule and the number of inference steps.

    Raises:
        ValueError: if both `timesteps` and `sigmas` are given, or the scheduler does not
            support the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        # The scheduler must explicitly accept a `timesteps` keyword.
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        return timesteps, len(timesteps)

    if sigmas is not None:
        # The scheduler must explicitly accept a `sigmas` keyword.
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        return timesteps, len(timesteps)

    # Default path: let the scheduler derive the schedule from the step count.
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    return scheduler.timesteps, num_inference_steps
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def resize_mask(mask, latent, process_first_frame_only=True):
    """Resize a video mask to the spatio-temporal shape of a latent tensor.

    Args:
        mask: Tensor of shape ``(batch, channels, frames, height, width)``.
        latent: Tensor whose trailing dims ``(frames, height, width)`` give the target size.
        process_first_frame_only: When True, the first mask frame is resized on its own to a
            single target frame (matching the causal VAE's special first frame) and the
            remaining frames are resized to fill the rest of the temporal extent.

    Returns:
        The trilinearly resized mask tensor.
    """
    target_fhw = list(latent.size()[2:])

    if not process_first_frame_only:
        # Single pass over all frames at once.
        return F.interpolate(
            mask,
            size=target_fhw,
            mode='trilinear',
            align_corners=False,
        )

    # Resize frame 0 by itself to exactly one output frame.
    head_size = [1] + target_fhw[1:]
    head = F.interpolate(
        mask[:, :, 0:1, :, :],
        size=head_size,
        mode='trilinear',
        align_corners=False,
    )

    remaining_frames = target_fhw[0] - 1
    if remaining_frames == 0:
        # The latent has a single frame; nothing left to resize.
        return head

    tail_size = [remaining_frames] + target_fhw[1:]
    tail = F.interpolate(
        mask[:, :, 1:, :, :],
        size=tail_size,
        mode='trilinear',
        align_corners=False,
    )
    return torch.cat([head, tail], dim=2)
|
| 259 |
+
def add_noise_to_reference_video(image, ratio=None):
    """Perturb a reference video with Gaussian noise for inpaint conditioning.

    Args:
        image: Tensor of shape ``(batch, channels, frames, height, width)``.
        ratio: Optional per-batch noise scale. When None, the scale is drawn from a
            log-normal distribution (mean -3.0, std 0.5 in log space).

    Returns:
        The noised tensor. Pixels exactly equal to -1 (the masked-out fill value)
        receive no noise and are returned unchanged.
    """
    if ratio is None:
        # Sample a per-sample log-sigma, then exponentiate to get a positive scale.
        log_sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
        sigma = torch.exp(log_sigma).to(image.dtype)
    else:
        sigma = ratio * torch.ones((image.shape[0],)).to(image.device, image.dtype)

    # Broadcast the per-sample sigma over channel/frame/spatial dims.
    noise = torch.randn_like(image) * sigma[:, None, None, None, None]
    # Keep masked-out regions (filled with -1) untouched.
    noise = torch.where(image == -1, torch.zeros_like(image), noise)
    return image + noise
|
| 271 |
+
|
| 272 |
+
@dataclass
class CogVideoXFunPipelineOutput(BaseOutput):
    r"""
    Output class for CogVideo pipelines.

    Args:
        videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Generated videos; see the class docstring for accepted shapes/containers.
    videos: torch.Tensor
|
| 286 |
+
|
| 287 |
+
class CogVideoXFunInpaintPipeline(DiffusionPipeline):
|
| 288 |
+
r"""
|
| 289 |
+
Pipeline for text-to-video generation using CogVideoX.
|
| 290 |
+
|
| 291 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 292 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
vae ([`AutoencoderKL`]):
|
| 296 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 297 |
+
text_encoder ([`T5EncoderModel`]):
|
| 298 |
+
Frozen text-encoder. CogVideoX_Fun uses
|
| 299 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
| 300 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
| 301 |
+
tokenizer (`T5Tokenizer`):
|
| 302 |
+
Tokenizer of class
|
| 303 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 304 |
+
transformer ([`CogVideoXTransformer3DModel`]):
|
| 305 |
+
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
| 306 |
+
scheduler ([`SchedulerMixin`]):
|
| 307 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
| 308 |
+
"""
|
| 309 |
+
|
| 310 |
+
_optional_components = []
|
| 311 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 312 |
+
|
| 313 |
+
_callback_tensor_inputs = [
|
| 314 |
+
"latents",
|
| 315 |
+
"prompt_embeds",
|
| 316 |
+
"negative_prompt_embeds",
|
| 317 |
+
]
|
| 318 |
+
|
| 319 |
+
    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        """Register pipeline components and derive VAE scale factors and image/video processors."""
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        # Spatial downsampling factor of the VAE: 2 per downsampling block; falls back to 8.
        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        # Temporal compression ratio of the causal video VAE; falls back to 4.
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
        )

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

        # Same spatial factor, kept under the name the image processors expect.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        # Masks are kept as raw grayscale values: no normalization, no binarization.
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=False, do_convert_grayscale=True
        )
|
| 347 |
+
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Tokenize `prompt` and embed it with the T5 text encoder.

        Warns when input is truncated to `max_sequence_length`, then repeats the
        embeddings `num_videos_per_prompt` times along the batch dimension.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Re-tokenize without truncation to detect (and report) any dropped text.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds
|
| 389 |
+
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype

        Returns:
            `(prompt_embeds, negative_prompt_embeds)`; the negative embeddings are only computed when
            `do_classifier_free_guidance` is True and none were passed in.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            # No raw prompt: infer batch size from the precomputed embeddings.
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # Default to the empty string, broadcast to the batch size.
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds
|
| 470 |
+
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        video_length,
        dtype,
        device,
        generator,
        latents=None,
        video=None,
        timestep=None,
        is_strength_max=True,
        return_noise=False,
        return_video_latents=False,
    ):
        """Build the initial latent tensor for denoising.

        When `latents` is None the latents start as pure noise (strength == 1) or as the
        VAE-encoded `video` noised to `timestep` (strength < 1). When `latents` is given,
        it is treated as noise and scaled by the scheduler's init sigma.

        Returns a tuple `(latents[, noise][, video_latents])` depending on the
        `return_noise` / `return_video_latents` flags.
        """
        # Latent layout: [batch, frames, channels, height, width] (frame-major for CogVideoX).
        shape = (
            batch_size,
            (video_length - 1) // self.vae_scale_factor_temporal + 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # Encode the reference video only when its latents are actually needed.
        if return_video_latents or (latents is None and not is_strength_max):
            video = video.to(device=device, dtype=self.vae.dtype)

            # Encode one sample at a time to bound VAE memory usage.
            bs = 1
            new_video = []
            for i in range(0, video.shape[0], bs):
                video_bs = video[i : i + bs]
                video_bs = self.vae.encode(video_bs)[0]
                video_bs = video_bs.sample()
                new_video.append(video_bs)
            video = torch.cat(new_video, dim = 0)
            video = video * self.vae.config.scaling_factor

            # Tile to the requested batch size, then move channels after frames.
            video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
            video_latents = video_latents.to(device=device, dtype=dtype)
            video_latents = rearrange(video_latents, "b c f h w -> b f c h w")

        if latents is None:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            # if strength is 1. then initialise the latents to noise, else initial to image + noise
            latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
            # if pure noise then scale the initial latents by the Scheduler's init sigma
            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
        else:
            noise = latents.to(device)
            latents = noise * self.scheduler.init_noise_sigma

        # scale the initial noise by the standard deviation required by the scheduler
        outputs = (latents,)

        if return_noise:
            outputs += (noise,)

        if return_video_latents:
            outputs += (video_latents,)

        return outputs
|
| 538 |
+
    def prepare_mask_latents(
        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
    ):
        """VAE-encode the inpaint mask and the masked video into latent space.

        Returns `(mask, masked_image_latents)`; note that the returned `mask` is the
        VAE-encoded (distribution mode) latent of the input mask, not the raw pixel mask.
        Either input may be None, in which case it is passed through / returned as None.
        """
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision

        if mask is not None:
            mask = mask.to(device=device, dtype=self.vae.dtype)
            # Encode one sample at a time to bound VAE memory usage.
            bs = 1
            new_mask = []
            for i in range(0, mask.shape[0], bs):
                mask_bs = mask[i : i + bs]
                mask_bs = self.vae.encode(mask_bs)[0]
                # Use the distribution mode (deterministic) rather than sampling.
                mask_bs = mask_bs.mode()
                new_mask.append(mask_bs)
            mask = torch.cat(new_mask, dim = 0)
            mask = mask * self.vae.config.scaling_factor

        if masked_image is not None:
            # Optionally perturb the reference frames, matching the noise augmentation
            # the inpaint transformer was trained with.
            if self.transformer.config.add_noise_in_inpaint_model:
                masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
            masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
            bs = 1
            new_mask_pixel_values = []
            for i in range(0, masked_image.shape[0], bs):
                mask_pixel_values_bs = masked_image[i : i + bs]
                mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
                mask_pixel_values_bs = mask_pixel_values_bs.mode()
                new_mask_pixel_values.append(mask_pixel_values_bs)
            masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
            masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
        else:
            masked_image_latents = None

        return mask, masked_image_latents
|
| 575 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 576 |
+
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
| 577 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 578 |
+
|
| 579 |
+
frames = self.vae.decode(latents).sample
|
| 580 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 581 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
| 582 |
+
frames = frames.cpu().float().numpy()
|
| 583 |
+
return frames
|
| 584 |
+
|
| 585 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 586 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 587 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 588 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 589 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 590 |
+
# and should be between [0, 1]
|
| 591 |
+
|
| 592 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 593 |
+
extra_step_kwargs = {}
|
| 594 |
+
if accepts_eta:
|
| 595 |
+
extra_step_kwargs["eta"] = eta
|
| 596 |
+
|
| 597 |
+
# check if the scheduler accepts generator
|
| 598 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 599 |
+
if accepts_generator:
|
| 600 |
+
extra_step_kwargs["generator"] = generator
|
| 601 |
+
return extra_step_kwargs
|
| 602 |
+
|
| 603 |
+
    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate call arguments; raises ValueError/TypeError on any inconsistency.

        Checks: spatial dims divisible by 8, callback tensor names known, exactly one
        of `prompt`/`prompt_embeds` supplied, no prompt/embeds double-specification,
        and matching shapes for positive/negative embeddings.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
|
| 655 |
+
    def fuse_qkv_projections(self) -> None:
        r"""Enables fused QKV projections."""
        # Flag is set before fusing so unfuse_qkv_projections knows fusion was requested.
        self.fusing_transformer = True
        self.transformer.fuse_qkv_projections()
|
| 660 |
+
def unfuse_qkv_projections(self) -> None:
|
| 661 |
+
r"""Disable QKV projection fusion if enabled."""
|
| 662 |
+
if not self.fusing_transformer:
|
| 663 |
+
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
| 664 |
+
else:
|
| 665 |
+
self.transformer.unfuse_qkv_projections()
|
| 666 |
+
self.fusing_transformer = False
|
| 667 |
+
|
| 668 |
+
    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 3D RoPE (cos, sin) tables for the transformer.

        `num_frames` here is the number of latent frames. CogVideoX 1.0 (no temporal
        patching) uses a crop-region grid; CogVideoX 1.5 (patch_size_t set) uses
        sliced embeddings over the temporally-patched frame count.
        """
        # Token grid size after spatial patching of the latent.
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        p_t = self.transformer.config.patch_size_t

        # Base grid the model was trained on, used to scale/crop the RoPE frequencies.
        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p

        if p_t is None:
            # CogVideoX 1.0
            grid_crops_coords = get_resize_crop_region_for_grid(
                (grid_height, grid_width), base_size_width, base_size_height
            )
            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=grid_crops_coords,
                grid_size=(grid_height, grid_width),
                temporal_size=num_frames,
            )
        else:
            # CogVideoX 1.5
            # Round the frame count up to a multiple of the temporal patch size.
            base_num_frames = (num_frames + p_t - 1) // p_t

            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=None,
                grid_size=(grid_height, grid_width),
                temporal_size=base_num_frames,
                grid_type="slice",
                max_size=(base_size_height, base_size_width),
            )

        freqs_cos = freqs_cos.to(device=device)
        freqs_sin = freqs_sin.to(device=device)
        return freqs_cos, freqs_sin
|
| 712 |
+
    @property
    def guidance_scale(self):
        """Classifier-free guidance scale set for the current __call__."""
        return self._guidance_scale

    @property
    def num_timesteps(self):
        """Number of denoising timesteps in the current run."""
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        """Extra kwargs forwarded to the transformer's attention processors."""
        return self._attention_kwargs

    @property
    def interrupt(self):
        """True when the current generation has been asked to stop early."""
        return self._interrupt
|
| 728 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
| 729 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
| 730 |
+
# get the original timestep using init_timestep
|
| 731 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 732 |
+
|
| 733 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
| 734 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
| 735 |
+
|
| 736 |
+
return timesteps, num_inference_steps - t_start
|
| 737 |
+
|
| 738 |
+
@torch.no_grad()
|
| 739 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 740 |
+
def __call__(
|
| 741 |
+
self,
|
| 742 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 743 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 744 |
+
height: int = 480,
|
| 745 |
+
width: int = 720,
|
| 746 |
+
video: Union[torch.FloatTensor] = None,
|
| 747 |
+
mask_video: Union[torch.FloatTensor] = None,
|
| 748 |
+
masked_video_latents: Union[torch.FloatTensor] = None,
|
| 749 |
+
num_frames: int = 49,
|
| 750 |
+
num_inference_steps: int = 50,
|
| 751 |
+
timesteps: Optional[List[int]] = None,
|
| 752 |
+
guidance_scale: float = 6,
|
| 753 |
+
use_dynamic_cfg: bool = False,
|
| 754 |
+
num_videos_per_prompt: int = 1,
|
| 755 |
+
eta: float = 0.0,
|
| 756 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 757 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 758 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 759 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 760 |
+
output_type: str = "numpy",
|
| 761 |
+
return_dict: bool = False,
|
| 762 |
+
callback_on_step_end: Optional[
|
| 763 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 764 |
+
] = None,
|
| 765 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 766 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 767 |
+
max_sequence_length: int = 226,
|
| 768 |
+
strength: float = 1,
|
| 769 |
+
noise_aug_strength: float = 0.0563,
|
| 770 |
+
comfyui_progressbar: bool = False,
|
| 771 |
+
temporal_multidiffusion_stride: int = 16,
|
| 772 |
+
use_trimask: bool = False,
|
| 773 |
+
zero_out_mask_region: bool = False,
|
| 774 |
+
binarize_mask: bool = False,
|
| 775 |
+
skip_unet: bool = False,
|
| 776 |
+
use_vae_mask: bool = False,
|
| 777 |
+
stack_mask: bool = False,
|
| 778 |
+
) -> Union[CogVideoXFunPipelineOutput, Tuple]:
|
| 779 |
+
"""
|
| 780 |
+
Function invoked when calling the pipeline for generation.
|
| 781 |
+
|
| 782 |
+
Args:
|
| 783 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 784 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 785 |
+
instead.
|
| 786 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 787 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 788 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 789 |
+
less than `1`).
|
| 790 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 791 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 792 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 793 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 794 |
+
num_frames (`int`, defaults to `48`):
|
| 795 |
+
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
| 796 |
+
contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
|
| 797 |
+
num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
|
| 798 |
+
needs to be satisfied is that of divisibility mentioned above.
|
| 799 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 800 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 801 |
+
expense of slower inference.
|
| 802 |
+
timesteps (`List[int]`, *optional*):
|
| 803 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 804 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 805 |
+
passed will be used. Must be in descending order.
|
| 806 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 807 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 808 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 809 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 810 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 811 |
+
usually at the expense of lower image quality.
|
| 812 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 813 |
+
The number of videos to generate per prompt.
|
| 814 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 815 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 816 |
+
to make generation deterministic.
|
| 817 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 818 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 819 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 820 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 821 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 822 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 823 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 824 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 825 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 826 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 827 |
+
argument.
|
| 828 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 829 |
+
The output format of the generate image. Choose between
|
| 830 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 831 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 832 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 833 |
+
of a plain tuple.
|
| 834 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 835 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 836 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 837 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 838 |
+
`callback_on_step_end_tensor_inputs`.
|
| 839 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 840 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 841 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 842 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 843 |
+
max_sequence_length (`int`, defaults to `226`):
|
| 844 |
+
Maximum sequence length in encoded prompt. Must be consistent with
|
| 845 |
+
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
| 846 |
+
|
| 847 |
+
Examples:
|
| 848 |
+
|
| 849 |
+
Returns:
|
| 850 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
|
| 851 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
|
| 852 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 853 |
+
"""
|
| 854 |
+
|
| 855 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 856 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 857 |
+
|
| 858 |
+
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
| 859 |
+
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
| 860 |
+
num_frames = num_frames or self.transformer.config.sample_frames
|
| 861 |
+
|
| 862 |
+
num_videos_per_prompt = 1
|
| 863 |
+
|
| 864 |
+
# 1. Check inputs. Raise error if not correct
|
| 865 |
+
self.check_inputs(
|
| 866 |
+
prompt,
|
| 867 |
+
height,
|
| 868 |
+
width,
|
| 869 |
+
negative_prompt,
|
| 870 |
+
callback_on_step_end_tensor_inputs,
|
| 871 |
+
prompt_embeds,
|
| 872 |
+
negative_prompt_embeds,
|
| 873 |
+
)
|
| 874 |
+
self._guidance_scale = guidance_scale
|
| 875 |
+
self._attention_kwargs = attention_kwargs
|
| 876 |
+
self._interrupt = False
|
| 877 |
+
|
| 878 |
+
# 2. Default call parameters
|
| 879 |
+
if prompt is not None and isinstance(prompt, str):
|
| 880 |
+
batch_size = 1
|
| 881 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 882 |
+
batch_size = len(prompt)
|
| 883 |
+
else:
|
| 884 |
+
batch_size = prompt_embeds.shape[0]
|
| 885 |
+
|
| 886 |
+
device = self._execution_device
|
| 887 |
+
|
| 888 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 889 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 890 |
+
# corresponds to doing no classifier free guidance.
|
| 891 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 892 |
+
logger.info(f'Use cfg: {do_classifier_free_guidance}, guidance_scale={guidance_scale}')
|
| 893 |
+
|
| 894 |
+
# 3. Encode input prompt
|
| 895 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 896 |
+
prompt,
|
| 897 |
+
negative_prompt,
|
| 898 |
+
do_classifier_free_guidance,
|
| 899 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 900 |
+
prompt_embeds=prompt_embeds,
|
| 901 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 902 |
+
max_sequence_length=max_sequence_length,
|
| 903 |
+
device=device,
|
| 904 |
+
)
|
| 905 |
+
if do_classifier_free_guidance:
|
| 906 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 907 |
+
|
| 908 |
+
# 4. set timesteps
|
| 909 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 910 |
+
timesteps, num_inference_steps = self.get_timesteps(
|
| 911 |
+
num_inference_steps=num_inference_steps, strength=strength, device=device
|
| 912 |
+
)
|
| 913 |
+
self._num_timesteps = len(timesteps)
|
| 914 |
+
if comfyui_progressbar:
|
| 915 |
+
from comfy.utils import ProgressBar
|
| 916 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 917 |
+
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
| 918 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
|
| 919 |
+
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
| 920 |
+
is_strength_max = strength == 1.0
|
| 921 |
+
|
| 922 |
+
# 5. Prepare latents.
|
| 923 |
+
if video is not None:
|
| 924 |
+
video_length = video.shape[2]
|
| 925 |
+
init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 926 |
+
init_video = init_video.to(dtype=torch.float32)
|
| 927 |
+
init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 928 |
+
else:
|
| 929 |
+
video_length = num_frames
|
| 930 |
+
init_video = None
|
| 931 |
+
|
| 932 |
+
# Magvae needs the number of frames to be 4n + 1.
|
| 933 |
+
local_latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 934 |
+
# For CogVideoX 1.5, the latent frames should be clipped to make it divisible by patch_size_t
|
| 935 |
+
patch_size_t = self.transformer.config.patch_size_t
|
| 936 |
+
additional_frames = 0
|
| 937 |
+
if patch_size_t is not None and local_latent_length % patch_size_t != 0:
|
| 938 |
+
additional_frames = local_latent_length % patch_size_t
|
| 939 |
+
num_frames -= additional_frames * self.vae_scale_factor_temporal
|
| 940 |
+
if num_frames <= 0:
|
| 941 |
+
num_frames = 1
|
| 942 |
+
|
| 943 |
+
num_channels_latents = self.vae.config.latent_channels
|
| 944 |
+
num_channels_transformer = self.transformer.config.in_channels
|
| 945 |
+
return_image_latents = num_channels_transformer == num_channels_latents
|
| 946 |
+
|
| 947 |
+
latents_outputs = self.prepare_latents(
|
| 948 |
+
batch_size * num_videos_per_prompt,
|
| 949 |
+
num_channels_latents,
|
| 950 |
+
height,
|
| 951 |
+
width,
|
| 952 |
+
video_length,
|
| 953 |
+
prompt_embeds.dtype,
|
| 954 |
+
device,
|
| 955 |
+
generator,
|
| 956 |
+
latents,
|
| 957 |
+
video=init_video,
|
| 958 |
+
timestep=latent_timestep,
|
| 959 |
+
is_strength_max=is_strength_max,
|
| 960 |
+
return_noise=True,
|
| 961 |
+
return_video_latents=return_image_latents,
|
| 962 |
+
)
|
| 963 |
+
if return_image_latents:
|
| 964 |
+
latents, noise, image_latents = latents_outputs
|
| 965 |
+
else:
|
| 966 |
+
latents, noise = latents_outputs
|
| 967 |
+
if comfyui_progressbar:
|
| 968 |
+
pbar.update(1)
|
| 969 |
+
|
| 970 |
+
if mask_video is not None:
|
| 971 |
+
if (mask_video == 255).all():
|
| 972 |
+
mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype)
|
| 973 |
+
masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
| 974 |
+
|
| 975 |
+
mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
|
| 976 |
+
masked_video_latents_input = (
|
| 977 |
+
torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
|
| 978 |
+
)
|
| 979 |
+
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
|
| 980 |
+
else:
|
| 981 |
+
# Prepare mask latent variables
|
| 982 |
+
video_length = video.shape[2]
|
| 983 |
+
mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 984 |
+
if use_trimask:
|
| 985 |
+
mask_condition = torch.where(mask_condition > 0.75, 1., mask_condition)
|
| 986 |
+
mask_condition = torch.where((mask_condition <= 0.75) * (mask_condition >= 0.25), 127. / 255., mask_condition)
|
| 987 |
+
mask_condition = torch.where(mask_condition < 0.25, 0., mask_condition)
|
| 988 |
+
else:
|
| 989 |
+
mask_condition = torch.where(mask_condition > 0.5, 1., 0.)
|
| 990 |
+
|
| 991 |
+
mask_condition = mask_condition.to(dtype=torch.float32)
|
| 992 |
+
mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
|
| 993 |
+
|
| 994 |
+
if num_channels_transformer != num_channels_latents:
|
| 995 |
+
mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
|
| 996 |
+
if masked_video_latents is None:
|
| 997 |
+
if zero_out_mask_region:
|
| 998 |
+
masked_video = init_video * (mask_condition_tile < 0.75) + torch.ones_like(init_video) * (mask_condition_tile > 0.75) * -1
|
| 999 |
+
else:
|
| 1000 |
+
masked_video = init_video
|
| 1001 |
+
else:
|
| 1002 |
+
masked_video = masked_video_latents
|
| 1003 |
+
|
| 1004 |
+
mask_encoded, masked_video_latents = self.prepare_mask_latents(
|
| 1005 |
+
1 - mask_condition_tile if use_vae_mask else None,
|
| 1006 |
+
masked_video,
|
| 1007 |
+
batch_size,
|
| 1008 |
+
height,
|
| 1009 |
+
width,
|
| 1010 |
+
prompt_embeds.dtype,
|
| 1011 |
+
device,
|
| 1012 |
+
generator,
|
| 1013 |
+
do_classifier_free_guidance,
|
| 1014 |
+
noise_aug_strength=noise_aug_strength,
|
| 1015 |
+
)
|
| 1016 |
+
if not use_vae_mask and not stack_mask:
|
| 1017 |
+
mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
|
| 1018 |
+
if binarize_mask:
|
| 1019 |
+
if use_trimask:
|
| 1020 |
+
mask_latents = torch.where(mask_latents > 0.75, 1., mask_latents)
|
| 1021 |
+
mask_latents = torch.where((mask_latents <= 0.75) * (mask_latents >= 0.25), 0.5, mask_latents)
|
| 1022 |
+
mask_latents = torch.where(mask_latents < 0.25, 0., mask_latents)
|
| 1023 |
+
else:
|
| 1024 |
+
mask_latents = torch.where(mask_latents < 0.9, 0., 1.).to(mask_latents.dtype)
|
| 1025 |
+
|
| 1026 |
+
mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
|
| 1027 |
+
|
| 1028 |
+
mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
|
| 1029 |
+
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
|
| 1030 |
+
|
| 1031 |
+
mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
|
| 1032 |
+
mask = rearrange(mask, "b c f h w -> b f c h w")
|
| 1033 |
+
elif stack_mask:
|
| 1034 |
+
mask_latents = torch.cat([
|
| 1035 |
+
torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
|
| 1036 |
+
mask_condition[:, :, 1:],
|
| 1037 |
+
], dim=2)
|
| 1038 |
+
mask_latents = mask_latents.view(
|
| 1039 |
+
mask_latents.shape[0],
|
| 1040 |
+
mask_latents.shape[2] // 4,
|
| 1041 |
+
4,
|
| 1042 |
+
mask_latents.shape[3],
|
| 1043 |
+
mask_latents.shape[4],
|
| 1044 |
+
)
|
| 1045 |
+
mask_latents = mask_latents.transpose(1, 2)
|
| 1046 |
+
mask_latents = resize_mask(1 - mask_latents, masked_video_latents).to(latents.device, latents.dtype)
|
| 1047 |
+
mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
|
| 1048 |
+
else:
|
| 1049 |
+
mask_input = (
|
| 1050 |
+
torch.cat([mask_encoded] * 2) if do_classifier_free_guidance else mask_encoded
|
| 1051 |
+
)
|
| 1052 |
+
|
| 1053 |
+
masked_video_latents_input = (
|
| 1054 |
+
torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
|
| 1055 |
+
)
|
| 1056 |
+
|
| 1057 |
+
mask_input = rearrange(mask_input, "b c f h w -> b f c h w")
|
| 1058 |
+
masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w")
|
| 1059 |
+
|
| 1060 |
+
# concat(binary mask, encode(mask * video))
|
| 1061 |
+
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
|
| 1062 |
+
else:
|
| 1063 |
+
mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
|
| 1064 |
+
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
|
| 1065 |
+
mask = rearrange(mask, "b c f h w -> b f c h w")
|
| 1066 |
+
|
| 1067 |
+
inpaint_latents = None
|
| 1068 |
+
else:
|
| 1069 |
+
if num_channels_transformer != num_channels_latents:
|
| 1070 |
+
mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
| 1071 |
+
masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
| 1072 |
+
|
| 1073 |
+
mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
| 1074 |
+
masked_video_latents_input = (
|
| 1075 |
+
torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
|
| 1076 |
+
)
|
| 1077 |
+
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
|
| 1078 |
+
else:
|
| 1079 |
+
mask = torch.zeros_like(init_video[:, :1])
|
| 1080 |
+
mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
|
| 1081 |
+
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
|
| 1082 |
+
mask = rearrange(mask, "b c f h w -> b f c h w")
|
| 1083 |
+
|
| 1084 |
+
inpaint_latents = None
|
| 1085 |
+
if comfyui_progressbar:
|
| 1086 |
+
pbar.update(1)
|
| 1087 |
+
|
| 1088 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 1089 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 1090 |
+
logger.debug(f'Pipeline mask {mask_condition.shape} {mask_condition.dtype} {mask_condition.min()} {mask_condition.max()}')
|
| 1091 |
+
# 8. Denoising loop
|
| 1092 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 1093 |
+
latent_temporal_window_size = (num_frames - 1) // 4 + 1
|
| 1094 |
+
if latents.size(1) > latent_temporal_window_size:
|
| 1095 |
+
logger.info(f'Adopt temporal multidiffusion for the latents {latents.shape} {latents.dtype}')
|
| 1096 |
+
|
| 1097 |
+
# VAE experiment
|
| 1098 |
+
if skip_unet:
|
| 1099 |
+
masked_video_latents = rearrange(masked_video_latents, "b c f h w -> b f c h w")
|
| 1100 |
+
if output_type == "numpy":
|
| 1101 |
+
video = self.decode_latents(masked_video_latents)
|
| 1102 |
+
elif not output_type == "latent":
|
| 1103 |
+
video = self.decode_latents(masked_video_latents)
|
| 1104 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 1105 |
+
else:
|
| 1106 |
+
video = masked_video_latents
|
| 1107 |
+
|
| 1108 |
+
# Offload all models
|
| 1109 |
+
self.maybe_free_model_hooks()
|
| 1110 |
+
|
| 1111 |
+
if not return_dict:
|
| 1112 |
+
video = torch.from_numpy(video)
|
| 1113 |
+
|
| 1114 |
+
return CogVideoXFunPipelineOutput(videos=video)
|
| 1115 |
+
|
| 1116 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1117 |
+
# for DPM-solver++
|
| 1118 |
+
old_pred_original_sample = None
|
| 1119 |
+
for i, t in enumerate(timesteps):
|
| 1120 |
+
if self.interrupt:
|
| 1121 |
+
continue
|
| 1122 |
+
|
| 1123 |
+
def _sample(_latents, _inpaint_latents):
|
| 1124 |
+
# 7. Create rotary embeds if required
|
| 1125 |
+
image_rotary_emb = (
|
| 1126 |
+
self._prepare_rotary_positional_embeddings(height, width, _latents.size(1), device)
|
| 1127 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
| 1128 |
+
else None
|
| 1129 |
+
)
|
| 1130 |
+
|
| 1131 |
+
latent_model_input = torch.cat([_latents] * 2) if do_classifier_free_guidance else _latents
|
| 1132 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 1133 |
+
|
| 1134 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 1135 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 1136 |
+
|
| 1137 |
+
# predict noise model_output
|
| 1138 |
+
noise_pred = self.transformer(
|
| 1139 |
+
hidden_states=latent_model_input,
|
| 1140 |
+
encoder_hidden_states=prompt_embeds,
|
| 1141 |
+
timestep=timestep,
|
| 1142 |
+
image_rotary_emb=image_rotary_emb,
|
| 1143 |
+
return_dict=False,
|
| 1144 |
+
inpaint_latents=_inpaint_latents,
|
| 1145 |
+
)[0]
|
| 1146 |
+
noise_pred = noise_pred.float()
|
| 1147 |
+
|
| 1148 |
+
# perform guidance
|
| 1149 |
+
if use_dynamic_cfg:
|
| 1150 |
+
self._guidance_scale = 1 + guidance_scale * (
|
| 1151 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
| 1152 |
+
)
|
| 1153 |
+
if do_classifier_free_guidance:
|
| 1154 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1155 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 1156 |
+
|
| 1157 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1158 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
| 1159 |
+
_latents = self.scheduler.step(noise_pred, t, _latents, **extra_step_kwargs, return_dict=False)[0]
|
| 1160 |
+
else:
|
| 1161 |
+
_latents, old_pred_original_sample = self.scheduler.step(
|
| 1162 |
+
noise_pred,
|
| 1163 |
+
old_pred_original_sample,
|
| 1164 |
+
t,
|
| 1165 |
+
timesteps[i - 1] if i > 0 else None,
|
| 1166 |
+
_latents,
|
| 1167 |
+
**extra_step_kwargs,
|
| 1168 |
+
return_dict=False,
|
| 1169 |
+
)
|
| 1170 |
+
_latents = _latents.to(prompt_embeds.dtype)
|
| 1171 |
+
return _latents
|
| 1172 |
+
|
| 1173 |
+
if latents.size(1) <= latent_temporal_window_size:
|
| 1174 |
+
latents = _sample(latents, inpaint_latents)
|
| 1175 |
+
else:
|
| 1176 |
+
# adopt temporal multidiffusion
|
| 1177 |
+
latents_canvas = torch.zeros_like(latents).float()
|
| 1178 |
+
weights_canvas = torch.zeros(1, latents.size(1), 1, 1, 1).to(latents.device).float()
|
| 1179 |
+
temporal_stride = temporal_multidiffusion_stride // 4
|
| 1180 |
+
assert latent_temporal_window_size > temporal_stride
|
| 1181 |
+
|
| 1182 |
+
time_beg = 0
|
| 1183 |
+
while time_beg < latents.size(1):
|
| 1184 |
+
time_end = min(time_beg + latent_temporal_window_size, latents.size(1))
|
| 1185 |
+
|
| 1186 |
+
latents_i = latents[:, time_beg:time_end]
|
| 1187 |
+
if inpaint_latents is not None:
|
| 1188 |
+
inpaint_latents_i = inpaint_latents[:, time_beg:time_end]
|
| 1189 |
+
else:
|
| 1190 |
+
inpaint_latents_i = None
|
| 1191 |
+
|
| 1192 |
+
latents_i = _sample(latents_i, inpaint_latents_i)
|
| 1193 |
+
|
| 1194 |
+
weights_i = torch.ones(1, time_end - time_beg, 1, 1, 1).to(latents.device).to(latents.dtype)
|
| 1195 |
+
if time_beg > 0 and temporal_stride > 0:
|
| 1196 |
+
weights_i[:, :temporal_stride] = (torch.linspace(0., 1., temporal_stride + 2)[1:-1]
|
| 1197 |
+
.to(latents.device)
|
| 1198 |
+
.to(latents.dtype)
|
| 1199 |
+
.reshape(1, temporal_stride, 1, 1, 1))
|
| 1200 |
+
if time_end < latents.size(1) and temporal_stride > 0:
|
| 1201 |
+
weights_i[:, -temporal_stride:] = (torch.linspace(1., 0., temporal_stride + 2)[1:-1]
|
| 1202 |
+
.to(latents.device)
|
| 1203 |
+
.to(latents.dtype)
|
| 1204 |
+
.reshape(1, temporal_stride, 1, 1, 1))
|
| 1205 |
+
|
| 1206 |
+
latents_canvas[:, time_beg:time_end] += latents_i * weights_i
|
| 1207 |
+
weights_canvas[:, time_beg:time_end] += weights_i
|
| 1208 |
+
|
| 1209 |
+
time_beg = time_end - temporal_stride
|
| 1210 |
+
if time_end >= latents.size(1):
|
| 1211 |
+
break
|
| 1212 |
+
latents = (latents_canvas / weights_canvas).to(latents.dtype)
|
| 1213 |
+
|
| 1214 |
+
# call the callback, if provided
|
| 1215 |
+
if callback_on_step_end is not None:
|
| 1216 |
+
callback_kwargs = {}
|
| 1217 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1218 |
+
callback_kwargs[k] = locals()[k]
|
| 1219 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1220 |
+
|
| 1221 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1222 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1223 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 1224 |
+
|
| 1225 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1226 |
+
progress_bar.update()
|
| 1227 |
+
if comfyui_progressbar:
|
| 1228 |
+
pbar.update(1)
|
| 1229 |
+
|
| 1230 |
+
if output_type == "numpy":
|
| 1231 |
+
video = self.decode_latents(latents)
|
| 1232 |
+
elif not output_type == "latent":
|
| 1233 |
+
video = self.decode_latents(latents)
|
| 1234 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 1235 |
+
else:
|
| 1236 |
+
video = latents
|
| 1237 |
+
|
| 1238 |
+
# Offload all models
|
| 1239 |
+
self.maybe_free_model_hooks()
|
| 1240 |
+
|
| 1241 |
+
if not return_dict:
|
| 1242 |
+
video = torch.from_numpy(video)
|
| 1243 |
+
|
| 1244 |
+
return CogVideoXFunPipelineOutput(videos=video)
|
videox_fun/pipeline/pipeline_wan_fun.py
ADDED
|
@@ -0,0 +1,558 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 9 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 10 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 11 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 12 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 13 |
+
from diffusers.video_processor import VideoProcessor
|
| 14 |
+
|
| 15 |
+
from ..models import (AutoencoderKLWan, AutoTokenizer,
|
| 16 |
+
WanT5EncoderModel, WanTransformer3DModel)
|
| 17 |
+
|
| 18 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
EXAMPLE_DOC_STRING = """
|
| 22 |
+
Examples:
|
| 23 |
+
```python
|
| 24 |
+
pass
|
| 25 |
+
```
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 30 |
+
def retrieve_timesteps(
|
| 31 |
+
scheduler,
|
| 32 |
+
num_inference_steps: Optional[int] = None,
|
| 33 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 34 |
+
timesteps: Optional[List[int]] = None,
|
| 35 |
+
sigmas: Optional[List[float]] = None,
|
| 36 |
+
**kwargs,
|
| 37 |
+
):
|
| 38 |
+
"""
|
| 39 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 40 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
scheduler (`SchedulerMixin`):
|
| 44 |
+
The scheduler to get timesteps from.
|
| 45 |
+
num_inference_steps (`int`):
|
| 46 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 47 |
+
must be `None`.
|
| 48 |
+
device (`str` or `torch.device`, *optional*):
|
| 49 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 50 |
+
timesteps (`List[int]`, *optional*):
|
| 51 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 52 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 53 |
+
sigmas (`List[float]`, *optional*):
|
| 54 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 55 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 59 |
+
second element is the number of inference steps.
|
| 60 |
+
"""
|
| 61 |
+
if timesteps is not None and sigmas is not None:
|
| 62 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 63 |
+
if timesteps is not None:
|
| 64 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 65 |
+
if not accepts_timesteps:
|
| 66 |
+
raise ValueError(
|
| 67 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 68 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 69 |
+
)
|
| 70 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 71 |
+
timesteps = scheduler.timesteps
|
| 72 |
+
num_inference_steps = len(timesteps)
|
| 73 |
+
elif sigmas is not None:
|
| 74 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 75 |
+
if not accept_sigmas:
|
| 76 |
+
raise ValueError(
|
| 77 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 78 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 79 |
+
)
|
| 80 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 81 |
+
timesteps = scheduler.timesteps
|
| 82 |
+
num_inference_steps = len(timesteps)
|
| 83 |
+
else:
|
| 84 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 85 |
+
timesteps = scheduler.timesteps
|
| 86 |
+
return timesteps, num_inference_steps
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@dataclass
|
| 90 |
+
class WanPipelineOutput(BaseOutput):
|
| 91 |
+
r"""
|
| 92 |
+
Output class for CogVideo pipelines.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 96 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 97 |
+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
| 98 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
videos: torch.Tensor
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class WanFunPipeline(DiffusionPipeline):
|
| 105 |
+
r"""
|
| 106 |
+
Pipeline for text-to-video generation using Wan.
|
| 107 |
+
|
| 108 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 109 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
_optional_components = []
|
| 113 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 114 |
+
|
| 115 |
+
_callback_tensor_inputs = [
|
| 116 |
+
"latents",
|
| 117 |
+
"prompt_embeds",
|
| 118 |
+
"negative_prompt_embeds",
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
def __init__(
|
| 122 |
+
self,
|
| 123 |
+
tokenizer: AutoTokenizer,
|
| 124 |
+
text_encoder: WanT5EncoderModel,
|
| 125 |
+
vae: AutoencoderKLWan,
|
| 126 |
+
transformer: WanTransformer3DModel,
|
| 127 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 128 |
+
):
|
| 129 |
+
super().__init__()
|
| 130 |
+
|
| 131 |
+
self.register_modules(
|
| 132 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
| 133 |
+
)
|
| 134 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
|
| 135 |
+
|
| 136 |
+
def _get_t5_prompt_embeds(
|
| 137 |
+
self,
|
| 138 |
+
prompt: Union[str, List[str]] = None,
|
| 139 |
+
num_videos_per_prompt: int = 1,
|
| 140 |
+
max_sequence_length: int = 512,
|
| 141 |
+
device: Optional[torch.device] = None,
|
| 142 |
+
dtype: Optional[torch.dtype] = None,
|
| 143 |
+
):
|
| 144 |
+
device = device or self._execution_device
|
| 145 |
+
dtype = dtype or self.text_encoder.dtype
|
| 146 |
+
|
| 147 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 148 |
+
batch_size = len(prompt)
|
| 149 |
+
|
| 150 |
+
text_inputs = self.tokenizer(
|
| 151 |
+
prompt,
|
| 152 |
+
padding="max_length",
|
| 153 |
+
max_length=max_sequence_length,
|
| 154 |
+
truncation=True,
|
| 155 |
+
add_special_tokens=True,
|
| 156 |
+
return_tensors="pt",
|
| 157 |
+
)
|
| 158 |
+
text_input_ids = text_inputs.input_ids
|
| 159 |
+
prompt_attention_mask = text_inputs.attention_mask
|
| 160 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 161 |
+
|
| 162 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 163 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 164 |
+
logger.warning(
|
| 165 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 166 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
|
| 170 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
|
| 171 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 172 |
+
|
| 173 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 174 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 175 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 176 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 177 |
+
|
| 178 |
+
return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
| 179 |
+
|
| 180 |
+
def encode_prompt(
|
| 181 |
+
self,
|
| 182 |
+
prompt: Union[str, List[str]],
|
| 183 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 184 |
+
do_classifier_free_guidance: bool = True,
|
| 185 |
+
num_videos_per_prompt: int = 1,
|
| 186 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 187 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 188 |
+
max_sequence_length: int = 512,
|
| 189 |
+
device: Optional[torch.device] = None,
|
| 190 |
+
dtype: Optional[torch.dtype] = None,
|
| 191 |
+
):
|
| 192 |
+
r"""
|
| 193 |
+
Encodes the prompt into text encoder hidden states.
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 197 |
+
prompt to be encoded
|
| 198 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 199 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 200 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 201 |
+
less than `1`).
|
| 202 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 203 |
+
Whether to use classifier free guidance or not.
|
| 204 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 205 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
| 206 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 207 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 208 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 209 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 210 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 211 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 212 |
+
argument.
|
| 213 |
+
device: (`torch.device`, *optional*):
|
| 214 |
+
torch device
|
| 215 |
+
dtype: (`torch.dtype`, *optional*):
|
| 216 |
+
torch dtype
|
| 217 |
+
"""
|
| 218 |
+
device = device or self._execution_device
|
| 219 |
+
|
| 220 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 221 |
+
if prompt is not None:
|
| 222 |
+
batch_size = len(prompt)
|
| 223 |
+
else:
|
| 224 |
+
batch_size = prompt_embeds.shape[0]
|
| 225 |
+
|
| 226 |
+
if prompt_embeds is None:
|
| 227 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 228 |
+
prompt=prompt,
|
| 229 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 230 |
+
max_sequence_length=max_sequence_length,
|
| 231 |
+
device=device,
|
| 232 |
+
dtype=dtype,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 236 |
+
negative_prompt = negative_prompt or ""
|
| 237 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 238 |
+
|
| 239 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 240 |
+
raise TypeError(
|
| 241 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 242 |
+
f" {type(prompt)}."
|
| 243 |
+
)
|
| 244 |
+
elif batch_size != len(negative_prompt):
|
| 245 |
+
raise ValueError(
|
| 246 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 247 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 248 |
+
" the batch size of `prompt`."
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 252 |
+
prompt=negative_prompt,
|
| 253 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 254 |
+
max_sequence_length=max_sequence_length,
|
| 255 |
+
device=device,
|
| 256 |
+
dtype=dtype,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
return prompt_embeds, negative_prompt_embeds
|
| 260 |
+
|
| 261 |
+
def prepare_latents(
    self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
    """Create (or move) the initial latent tensor used to start denoising.

    Args:
        batch_size: Effective batch size (prompts x videos per prompt).
        num_channels_latents: Channel count of the latent space.
        num_frames: Number of video frames in pixel space; compressed by
            ``self.vae.temporal_compression_ratio`` in the latent shape.
        height / width: Pixel-space resolution; compressed by
            ``self.vae.spacial_compression_ratio``.
        dtype / device: Target dtype and device for the noise tensor.
        generator: A ``torch.Generator`` or list of them (one per sample).
        latents: Optional pre-made latents; if given they are only moved to
            ``device`` (NOTE(review): dtype is not converted here — presumably
            callers pass latents already in ``dtype``; confirm).

    Returns:
        The noise latents, scaled by ``scheduler.init_noise_sigma`` when the
        scheduler defines one.

    Raises:
        ValueError: If a list of generators does not match ``batch_size``.
    """
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    # Latent shape: (B, C, T', H', W') after the VAE's temporal/spatial compression.
    shape = (
        batch_size,
        num_channels_latents,
        (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
        height // self.vae.spacial_compression_ratio,
        width // self.vae.spacial_compression_ratio,
    )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    # (flow-matching schedulers may not define init_noise_sigma, hence the hasattr guard)
    if hasattr(self.scheduler, "init_noise_sigma"):
        latents = latents * self.scheduler.init_noise_sigma
    return latents
|
| 287 |
+
|
| 288 |
+
def decode_latents(self, latents: torch.Tensor) -> "np.ndarray":
    """Decode latents to pixel-space frames via the VAE.

    Despite the original annotation, this returns a float32 NumPy array in
    [0, 1] (the VAE output in [-1, 1] is rescaled, clamped, and moved to CPU).
    """
    frames = self.vae.decode(latents.to(self.vae.dtype)).sample
    frames = (frames / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    frames = frames.cpu().float().numpy()
    return frames
|
| 294 |
+
|
| 295 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 296 |
+
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the keyword arguments forwarded to ``self.scheduler.step``.

    Schedulers do not share a single ``step`` signature, so each optional
    argument is included only when the scheduler actually accepts it.
    ``eta`` (η) corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502), should be in [0, 1], and is ignored
    by non-DDIM schedulers.
    """
    # Inspect the scheduler's step signature once and reuse it for both checks.
    step_params = inspect.signature(self.scheduler.step).parameters.keys()

    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
|
| 312 |
+
|
| 313 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 314 |
+
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt,
    callback_on_step_end_tensor_inputs,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate user-supplied call arguments before generation starts.

    Checks resolution divisibility, callback tensor-input names, the mutual
    exclusivity of prompts vs. pre-computed embeddings, and embedding shape
    agreement. Raises ``ValueError`` on any violation; returns ``None``.
    """
    # Resolution must be compatible with the 8x spatial compression.
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # Every requested callback tensor must be one the pipeline can supply.
    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Prompt and prompt_embeds: exactly one must be provided (guard clauses;
    # each branch raises, so the original if/elif chain is equivalent).
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Pre-computed positive/negative embeddings must agree in shape.
    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
|
| 364 |
+
|
| 365 |
+
@property
def guidance_scale(self):
    # Classifier-free guidance weight; set from the `guidance_scale` argument in `__call__`.
    return self._guidance_scale

@property
def num_timesteps(self):
    # Number of scheduler timesteps for the current run; set in `__call__`.
    return self._num_timesteps

@property
def attention_kwargs(self):
    # Extra kwargs intended for attention processors; stored in `__call__`.
    return self._attention_kwargs

@property
def interrupt(self):
    # Cooperative-cancellation flag; checked at each denoising step in `__call__`.
    return self._interrupt
|
| 380 |
+
|
| 381 |
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    height: int = 480,
    width: int = 720,
    num_frames: int = 49,
    num_inference_steps: int = 50,
    timesteps: Optional[List[int]] = None,
    guidance_scale: float = 6,
    num_videos_per_prompt: int = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: str = "numpy",
    return_dict: bool = False,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    comfyui_progressbar: bool = False,
) -> Union[WanPipelineOutput, Tuple]:
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Text prompt(s); required unless `prompt_embeds` is given.
        negative_prompt (`str` or `List[str]`, *optional*):
            Negative prompt(s), used only when classifier-free guidance is
            active (`guidance_scale > 1`).
        height (`int`, defaults to 480) / width (`int`, defaults to 720):
            Pixel resolution; must be divisible by 8 (see `check_inputs`).
        num_frames (`int`, defaults to 49):
            Number of output video frames.
        num_inference_steps (`int`, defaults to 50):
            Number of denoising steps.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule forwarded to `retrieve_timesteps`.
        guidance_scale (`float`, defaults to 6):
            Classifier-free guidance weight; values > 1 enable CFG.
        output_type (`str`, defaults to `"numpy"`):
            `"numpy"`, `"latent"`, or any type accepted by
            `video_processor.postprocess_video`.
        return_dict (`bool`, defaults to `False`):
            NOTE(review): unlike most diffusers pipelines, a
            `WanPipelineOutput` is returned either way; `return_dict=False`
            only converts the decoded numpy video to a torch tensor first.
        comfyui_progressbar (`bool`, defaults to `False`):
            Also report progress through ComfyUI's ProgressBar.

    Examples:

    Returns:
        `WanPipelineOutput` whose `videos` field holds the generated video.
    """

    # Pipeline-callback objects carry their own tensor-input list.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
    # NOTE(review): num_videos_per_prompt is forced to 1 here, overriding the
    # caller's argument — confirm this limitation is intended.
    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds,
        negative_prompt_embeds,
    )
    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._interrupt = False

    # 2. Default call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    weight_dtype = self.text_encoder.dtype

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        negative_prompt,
        do_classifier_free_guidance,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
    )
    if do_classifier_free_guidance:
        # Concatenate [negative, positive] so a single forward pass covers both
        # CFG branches (assumes embeds behave as concatenable sequences here).
        prompt_embeds = negative_prompt_embeds + prompt_embeds

    # 4. Prepare timesteps (flow-matching schedulers take an extra `mu` shift)
    if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
    else:
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    self._num_timesteps = len(timesteps)
    if comfyui_progressbar:
        from comfy.utils import ProgressBar
        pbar = ProgressBar(num_inference_steps + 1)

    # 5. Prepare latents
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        latent_channels,
        num_frames,
        height,
        width,
        weight_dtype,
        device,
        generator,
        latents,
    )
    if comfyui_progressbar:
        pbar.update(1)

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # Token count per sample after patchification; the transformer needs it for padding.
    # NOTE(review): width maps to target_shape[2] and height to target_shape[3]
    # (swapped vs. the latent layout above) — the product is the same, so seq_len
    # is unaffected, but confirm the intended axis order.
    target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spacial_compression_ratio, height // self.vae.spacial_compression_ratio)
    seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
    # 7. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # Duplicate the latents so one transformer pass covers both CFG branches.
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if hasattr(self.scheduler, "scale_model_input"):
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            # predict noise model_output
            with torch.cuda.amp.autocast(dtype=weight_dtype):
                noise_pred = self.transformer(
                    x=latent_model_input,
                    context=prompt_embeds,
                    t=timestep,
                    seq_len=seq_len,
                )

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # Callbacks may mutate these tensors in place of the pipeline.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
            if comfyui_progressbar:
                pbar.update(1)

    # Convert latents to the requested output format.
    if output_type == "numpy":
        video = self.decode_latents(latents)
    elif not output_type == "latent":
        video = self.decode_latents(latents)
        video = self.video_processor.postprocess_video(video=video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        # NOTE(review): assumes `video` is a numpy array here; with
        # output_type="latent" (a torch tensor) this call would fail — confirm
        # callers never combine return_dict=False with latent output.
        video = torch.from_numpy(video)

    return WanPipelineOutput(videos=video)
|
videox_fun/reward/MPS/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
This folder is modified from the official [MPS](https://github.com/Kwai-Kolors/MPS/tree/main) repository.
|
videox_fun/reward/MPS/trainer/models/base_model.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@dataclass
class BaseModelConfig:
    """Empty marker base class; concrete model configs subclass it."""
    pass
|
videox_fun/reward/MPS/trainer/models/clip_model.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from transformers import CLIPModel as HFCLIPModel
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
|
| 5 |
+
from torch import nn, einsum
|
| 6 |
+
|
| 7 |
+
# Modified: import
|
| 8 |
+
# from trainer.models.base_model import BaseModelConfig
|
| 9 |
+
from .base_model import BaseModelConfig
|
| 10 |
+
|
| 11 |
+
from transformers import CLIPConfig
|
| 12 |
+
from typing import Any, Optional, Tuple, Union
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
# Modified: import
|
| 16 |
+
# from trainer.models.cross_modeling import Cross_model
|
| 17 |
+
from .cross_modeling import Cross_model
|
| 18 |
+
|
| 19 |
+
import gc
|
| 20 |
+
|
| 21 |
+
class XCLIPModel(HFCLIPModel):
    """CLIP variant returning per-token (sequence-level) projected features.

    Unlike the stock `transformers` CLIPModel, the feature getters here
    project the full last hidden state, not just the pooled output, so a
    downstream cross-attention module can attend over every token/patch.
    """

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Return ``(per-token text features, pooled EOS text features)``,
        both passed through ``self.text_projection``."""

        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Modified vs. upstream: project the full last hidden state instead of
        # only the pooled output.
        last_hidden_state = text_outputs[0]
        text_features = self.text_projection(last_hidden_state)

        # Keep the pooled (EOS-token) projection as well, for callers that
        # need a single text embedding.
        pooled_output = text_outputs[1]
        text_features_EOS = self.text_projection(pooled_output)

        return text_features, text_features_EOS

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Return per-patch image features projected by ``self.visual_projection``."""

        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Modified vs. upstream: project the full last hidden state instead of
        # only the pooled output.
        last_hidden_state = vision_outputs[0]
        image_features = self.visual_projection(last_hidden_state)

        return image_features
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@dataclass
class ClipModelConfig(BaseModelConfig):
    # Hydra-style import path used to instantiate the model class below.
    _target_: str = "trainer.models.clip_model.CLIPModel"
    # Default pretrained CLIP checkpoint identifier.
    pretrained_model_name_or_path: str = "openai/clip-vit-base-patch32"
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class CLIPModel(nn.Module):
    """MPS reward scorer: CLIP backbone + cross-attention fusion head.

    Wraps an :class:`XCLIPModel` (token-level CLIP features) and a
    ``Cross_model`` that fuses image tokens with condition-masked text tokens
    to produce a preference score embedding.
    """

    def __init__(self, config):
        super().__init__()
        # Modified: the original ckpt (a whole pickled model) was converted to
        # a `state_dict`, so we build from config instead of `from_pretrained`.
        self.model = XCLIPModel(config)
        self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)

    def get_text_features(self, *args, **kwargs):
        # Thin passthrough to the wrapped CLIP text tower.
        return self.model.get_text_features(*args, **kwargs)

    def get_image_features(self, *args, **kwargs):
        # Thin passthrough to the wrapped CLIP vision tower.
        return self.model.get_image_features(*args, **kwargs)

    def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
        """Score image/text pairs under a text condition.

        Returns a tuple ``(text_EOS, fused_cls)`` where ``text_EOS`` is the
        pooled text embedding and ``fused_cls`` is the first token of the
        cross-attended image/text representation.
        """
        outputs = ()

        text_f, text_EOS = self.model.get_text_features(text_inputs)  # B*77*1024
        outputs += text_EOS,

        # NOTE(review): `.half()` forces fp16 pixels regardless of model dtype;
        # the mask below is cast to the cross_model dtype, but confirm the
        # vision tower itself handles fp16 input under bf16 weights.
        image_f = self.model.get_image_features(image_inputs.half())  # 2B*257*1024
        # [B, 77, 1024]
        condition_f, _ = self.model.get_text_features(condition_inputs)  # B*5*1024

        # Build an additive attention mask: text tokens similar to any
        # condition token stay visible (0), others get -inf.
        sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
        sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
        sim_text_condition = sim_text_condition / sim_text_condition.max()
        mask = torch.where(sim_text_condition > 0.01, 0, float('-inf'))  # B*1*77

        # Modified: cast the mask to the cross_model parameter dtype so both
        # torch.float16 and torch.bfloat16 runs work.
        model_dtype = next(self.cross_model.parameters()).dtype
        mask = mask.repeat(1, image_f.shape[1], 1).to(model_dtype)  # B*257*77

        # Modified: the original MPS forward took two image batches and scored
        # which one matched the text better; this variant scores a single
        # image batch against the text.
        sim = self.cross_model(image_f, text_f, mask)
        outputs += sim[:, 0, :],

        return outputs

    @property
    def logit_scale(self):
        # Expose the wrapped CLIP's learned temperature parameter.
        return self.model.logit_scale

    def save(self, path):
        # Persists only the CLIP backbone (not the cross_model head).
        self.model.save_pretrained(path)
|
videox_fun/reward/MPS/trainer/models/cross_modeling.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import einsum, nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from einops import rearrange, repeat
|
| 5 |
+
|
| 6 |
+
# helper functions
|
| 7 |
+
|
| 8 |
+
def exists(val):
    """True when *val* carries a value, i.e. is not ``None``."""
    return not (val is None)
|
| 10 |
+
|
| 11 |
+
def default(val, d):
    """Return *val* when it is not ``None``, otherwise the fallback *d*."""
    if val is not None:
        return val
    return d
|
| 13 |
+
|
| 14 |
+
# normalization
|
| 15 |
+
# they use layernorm without bias, something that pytorch does not offer
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class LayerNorm(nn.Module):
    """Layer norm with a learnable scale but a permanently-zero bias.

    PyTorch's ``nn.LayerNorm`` has no bias-free option (at the time this was
    written), so the bias is registered as a non-trainable zero buffer.
    """

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.register_buffer("bias", torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, self.weight, self.bias)
|
| 26 |
+
|
| 27 |
+
# residual
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Residual(nn.Module):
    """Skip-connection wrapper: the module computes ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        inner = self.fn(x, *args, **kwargs)
        return inner + x
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# rotary positional embedding
|
| 40 |
+
# https://arxiv.org/abs/2104.09864
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE, https://arxiv.org/abs/2104.09864).

    Produces per-position angle tables of width ``dim`` (the half-dim
    frequency table duplicated once, matching `rotate_half`'s layout).
    """

    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        positions = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        # Outer product position x frequency == einsum("i , j -> i j", ...).
        freqs = torch.outer(positions, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def rotate_half(x):
    """Split the last dim into two halves (a, b) and return (-b, a).

    Equivalent to the einops form ``rearrange(x, "... (j d) -> ... j d", j=2)``
    followed by unbind/cat, written with torch-native ops.
    """
    halves = x.unflatten(-1, (2, -1))
    first, second = halves.unbind(dim=-2)
    return torch.cat((-second, first), dim=-1)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def apply_rotary_pos_emb(pos, t):
    """Apply rotary position embedding: ``t*cos(pos) + rotate_half(t)*sin(pos)``.

    The half-rotation is inlined (torch-native split of the last dim into
    two halves ``(a, b)`` mapped to ``(-b, a)``) so this function is
    self-contained.
    """
    halves = t.unflatten(-1, (2, -1))
    first, second = halves.unbind(dim=-2)
    rotated = torch.cat((-second, first), dim=-1)
    return t * pos.cos() + rotated * pos.sin()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
|
| 66 |
+
# https://arxiv.org/abs/2002.05202
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class SwiGLU(nn.Module):
    """SwiGLU gating (https://arxiv.org/abs/2002.05202).

    Splits the last dim in half into (value, gate) and returns
    ``silu(gate) * value``, halving the feature width.
    """

    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        return value * F.silu(gate)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# parallel attention and feedforward with residual
|
| 76 |
+
# discovered by Wang et al + EleutherAI from GPT-J fame
|
| 77 |
+
|
| 78 |
+
class ParallelTransformerBlock(nn.Module):
    """Parallel attention + feedforward block with a shared residual.

    Attention and FF branches are computed from one fused input projection
    (the PaLM / GPT-J layout). Uses multi-query attention — one key/value
    head shared across all query heads (https://arxiv.org/abs/1911.02150) —
    with rotary position embeddings and a SwiGLU feedforward.
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        # Output widths of the single fused projection: multi-head Q,
        # single-head K, single-head V, and 2x FF inner (for SwiGLU's gate).
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # Cached rotary table; non-persistent so it never hits checkpoints.
        self.register_buffer("pos_emb", None, persistent=False)

    def get_rotary_embedding(self, n, device):
        # Reuse the cached table when it is long enough; otherwise rebuild
        # and cache for subsequent (shorter or equal) sequences.
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x, attn_mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm
        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner — one matmul
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads: only Q is multi-head; K/V stay single-head (multi-query)
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings applied to both Q and the shared K
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # scale
        q = q * self.scale

        # similarity — note K has no head axis (broadcast across heads)
        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # extra attention mask - for masking out attention from text CLS token to padding
        if exists(attn_mask):
            attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

        # attention; subtract the (detached) row max for numerical stability
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values (single-head V broadcast across heads)
        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads
        out = rearrange(out, "b h n d -> b n (h d)")
        # parallel residual combine: attention output + feedforward output
        return self.attn_out(out) + self.ff_out(ff)
|
| 169 |
+
|
| 170 |
+
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
|
| 171 |
+
|
| 172 |
+
class CrossAttention(nn.Module):
    """Cross attention with multi-query keys/values (PaLM style) and an
    optional parallel SwiGLU feedforward branch.

    Queries come from ``x``, keys/values from ``context``; an additive
    ``mask`` (0 / -inf) gates which context tokens each query may attend to.
    """

    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=12,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # Single shared K/V head (multi-query): output width is dim_head * 2.
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # whether to have parallel feedforward
        ff_inner_dim = ff_mult * dim

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        ) if parallel_ff else None

    def forward(self, x, context, mask):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # pre-layernorm, for queries and context
        x = self.norm(x)
        context = self.context_norm(context)

        # get queries
        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # scale
        q = q * self.scale

        # get key / values (single head, broadcast across query heads)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        # query / key similarity
        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # attention — `mask` is additive (0 keeps, -inf removes), expanded per head
        mask = mask.unsqueeze(1).repeat(1,self.heads,1,1)
        sim = sim + mask # context mask
        # NOTE(review): unlike ParallelTransformerBlock, the row-max subtraction
        # here is NOT detached — presumably inconsequential for inference;
        # confirm if training through this path.
        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)

        # aggregate
        out = einsum('b h i j, b j d -> b h i d', attn, v)

        # merge and combine heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        # add parallel feedforward (for multimodal layers)
        if exists(self.ff):
            out = out + self.ff(x)

        return out
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class Cross_model(nn.Module):
    """Stack of residual (cross-attention, self-attention) layer pairs.

    Each layer first lets the query tokens attend over the context tokens
    (with an additive attention mask), then refines the queries with a
    parallel-transformer self-attention block.
    """

    def __init__(
        self,
        dim=512,
        layer_num=4,
        dim_head=64,
        heads=8,
        ff_mult=4
    ):
        super().__init__()
        # One (cross-attn, self-attn + ff) pair per layer, each wrapped in a
        # residual connection.
        self.layers = nn.ModuleList(
            nn.ModuleList([
                Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
                Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
            ])
            for _ in range(layer_num)
        )

    def forward(
        self,
        query_tokens,
        context_tokens,
        mask
    ):
        """Run ``query_tokens`` through every layer pair and return them."""
        for cross_attn, self_attn_ff in self.layers:
            query_tokens = cross_attn(query_tokens, context_tokens, mask)
            query_tokens = self_attn_ff(query_tokens)

        return query_tokens
|
videox_fun/reward/aesthetic_predictor_v2_5/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .siglip_v2_5 import (
|
| 2 |
+
AestheticPredictorV2_5Head,
|
| 3 |
+
AestheticPredictorV2_5Model,
|
| 4 |
+
AestheticPredictorV2_5Processor,
|
| 5 |
+
convert_v2_5_from_siglip,
|
| 6 |
+
)
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"AestheticPredictorV2_5Head",
|
| 10 |
+
"AestheticPredictorV2_5Model",
|
| 11 |
+
"AestheticPredictorV2_5Processor",
|
| 12 |
+
"convert_v2_5_from_siglip",
|
| 13 |
+
]
|
videox_fun/reward/aesthetic_predictor_v2_5/siglip_v2_5.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Borrowed from https://github.com/discus0434/aesthetic-predictor-v2-5/blob/3125a9e/src/aesthetic_predictor_v2_5/siglip_v2_5.py
|
| 2 |
+
import os
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
from os import PathLike
|
| 5 |
+
from typing import Final
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torchvision.transforms as transforms
|
| 10 |
+
from transformers import (
|
| 11 |
+
SiglipImageProcessor,
|
| 12 |
+
SiglipVisionConfig,
|
| 13 |
+
SiglipVisionModel,
|
| 14 |
+
logging,
|
| 15 |
+
)
|
| 16 |
+
from transformers.image_processing_utils import BatchFeature
|
| 17 |
+
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
|
| 18 |
+
|
| 19 |
+
logging.set_verbosity_error()
|
| 20 |
+
|
| 21 |
+
URL: Final[str] = (
|
| 22 |
+
"https://github.com/discus0434/aesthetic-predictor-v2-5/raw/main/models/aesthetic_predictor_v2_5.pth"
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class AestheticPredictorV2_5Head(nn.Module):
    """MLP head mapping a SigLIP image embedding to one aesthetic score.

    The layer widths (hidden -> 1024 -> 128 -> 64 -> 16 -> 1) and dropout
    placement must stay exactly as-is: the released checkpoint's state-dict
    keys depend on this ``nn.Sequential`` layout.
    """

    def __init__(self, config: SiglipVisionConfig) -> None:
        super().__init__()
        # Funnel from the encoder hidden size down to a scalar score.
        # Dropout layers only take effect in training mode.
        self.scoring_head = nn.Sequential(
            nn.Linear(config.hidden_size, 1024),
            nn.Dropout(0.5),
            nn.Linear(1024, 128),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.Dropout(0.5),
            nn.Linear(64, 16),
            nn.Dropout(0.2),
            nn.Linear(16, 1),
        )

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Return a ``(batch, 1)`` score tensor for ``image_embeds``.

        NOTE(review): callers normalize the embeddings before scoring;
        this head does no normalization itself.
        """
        return self.scoring_head(image_embeds)
|
| 44 |
+
|
| 45 |
+
class AestheticPredictorV2_5Model(SiglipVisionModel):
    """SigLIP vision encoder with an aesthetic-scoring MLP head on top.

    The encoder's pooled output is L2-normalized and passed through
    :class:`AestheticPredictorV2_5Head`, yielding one scalar score per image.
    """

    PATCH_SIZE = 14

    def __init__(self, config: SiglipVisionConfig, *args, **kwargs) -> None:
        super().__init__(config, *args, **kwargs)
        self.layers = AestheticPredictorV2_5Head(config)
        self.post_init()
        # Convenience PIL -> tensor preprocessing matching the 384x384
        # SigLIP checkpoint (pixels scaled to [-1, 1]).
        self.transforms = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
    ) -> tuple | ImageClassifierOutputWithNoAttention:
        """Score ``pixel_values``; optionally compute an MSE loss vs ``labels``.

        Returns an ``ImageClassifierOutputWithNoAttention`` whose ``logits``
        holds the ``(batch, 1)`` aesthetic scores, or the plain tuple
        ``(loss, prediction, image_embeds)`` when ``return_dict=False``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Bugfix: always request a dict from the encoder so that
        # `pooler_output` is attribute-accessible; previously `return_dict`
        # was forwarded, and a False value made `super().forward` return a
        # tuple, crashing the attribute access below.
        outputs = super().forward(
            pixel_values=pixel_values,
            return_dict=True,
        )
        image_embeds = outputs.pooler_output
        # Score the unit-normalized embedding, as the head's checkpoint expects.
        image_embeds_norm = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        prediction = self.layers(image_embeds_norm)

        loss = None
        if labels is not None:
            # Bugfix: the loss used to be computed as `nn.MSELoss()()` with
            # no arguments, which raised TypeError whenever labels were
            # supplied. Compare predictions against the labels instead.
            loss_fct = nn.MSELoss()
            loss = loss_fct(prediction.view(-1), labels.view(-1).to(prediction.dtype))

        if not return_dict:
            return (loss, prediction, image_embeds)

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=prediction,
            hidden_states=image_embeds,
        )
|
| 90 |
+
|
| 91 |
+
class AestheticPredictorV2_5Processor(SiglipImageProcessor):
    """Thin wrapper around ``SiglipImageProcessor``.

    Exists so the aesthetic predictor can expose its own processor class
    name while reusing SigLIP preprocessing unchanged; the pass-through
    ``__init__`` / ``__call__`` add no behavior.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs) -> BatchFeature:
        return super().__call__(*args, **kwargs)

    @classmethod
    def from_pretrained(
        # Bugfix: a classmethod receives the class as its first argument;
        # it was previously (misleadingly) named `self`.
        cls,
        pretrained_model_name_or_path: str
        | PathLike = "google/siglip-so400m-patch14-384",
        *args,
        **kwargs,
    ) -> "AestheticPredictorV2_5Processor":
        """Load the processor, defaulting to the SigLIP-SO400M-384 config."""
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
| 108 |
+
|
| 109 |
+
def convert_v2_5_from_siglip(
    predictor_name_or_path: str | PathLike | None = None,
    encoder_model_name: str = "google/siglip-so400m-patch14-384",
    *args,
    **kwargs,
) -> tuple[AestheticPredictorV2_5Model, AestheticPredictorV2_5Processor]:
    """Build the aesthetic predictor v2.5 from a SigLIP encoder checkpoint.

    Args:
        predictor_name_or_path: Local path to the scoring-head weights.
            When ``None`` or missing on disk, the weights are downloaded
            from the upstream GitHub release (``URL``).
        encoder_model_name: Hugging Face id (or local path) of the SigLIP
            vision encoder.
        *args, **kwargs: Forwarded to both ``from_pretrained`` calls.

    Returns:
        ``(model, processor)`` with the head weights loaded and the model
        switched to eval mode.

    Raises:
        TypeError: if the loaded checkpoint is not a state dict.
    """
    model = AestheticPredictorV2_5Model.from_pretrained(
        encoder_model_name, *args, **kwargs
    )

    processor = AestheticPredictorV2_5Processor.from_pretrained(
        encoder_model_name, *args, **kwargs
    )

    if predictor_name_or_path is None or not os.path.exists(predictor_name_or_path):
        state_dict = torch.hub.load_state_dict_from_url(URL, map_location="cpu")
    else:
        state_dict = torch.load(predictor_name_or_path, map_location="cpu")

    # Validate explicitly instead of `assert`, which is stripped under
    # `python -O`. Accept any dict: plain dicts are valid state dicts too,
    # so this also generalizes the old OrderedDict-only check.
    if not isinstance(state_dict, dict):
        raise TypeError(
            f"Expected a state dict, got {type(state_dict).__name__} "
            f"from {predictor_name_or_path or URL!r}"
        )

    model.layers.load_state_dict(state_dict)
    model.eval()

    return model, processor
|
videox_fun/reward/improved_aesthetic_predictor.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from transformers import CLIPModel
|
| 6 |
+
from torchvision.datasets.utils import download_url
|
| 7 |
+
|
| 8 |
+
URL = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/sac%2Blogos%2Bava1-l14-linearMSE.pth"
|
| 9 |
+
FILENAME = "sac+logos+ava1-l14-linearMSE.pth"
|
| 10 |
+
MD5 = "b1047fd767a00134b8fd6529bf19521a"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MLP(nn.Module):
    """Regression head of the improved aesthetic predictor.

    Maps a 768-dim CLIP image embedding to a single score. The positional
    ``nn.Sequential`` layout must not change: the downloaded checkpoint's
    state-dict keys are the layer indices.
    """

    def __init__(self):
        super().__init__()
        # 768 -> 1024 -> 128 -> 64 -> 16 -> 1 funnel with light dropout.
        blocks = [
            nn.Linear(768, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*blocks)

    def forward(self, embed):
        """Return the score tensor ``(batch, 1)`` for ``embed`` of shape ``(batch, 768)``."""
        return self.layers(embed)
|
| 31 |
+
|
| 32 |
+
class ImprovedAestheticPredictor(nn.Module):
    """CLIP-based aesthetic scorer (LAION "improved aesthetic predictor").

    A CLIP vision encoder produces 768-dim image embeddings, which are
    L2-normalized and scored by the small :class:`MLP` regression head.
    """

    def __init__(self, encoder_path="openai/clip-vit-large-patch14", predictor_path=None):
        # encoder_path: HF id / path of the CLIP backbone; must project to
        #   768 dims to match the MLP head.
        # predictor_path: local path to the head weights; when None or
        #   missing, they are downloaded into the torch hub cache.
        super().__init__()
        self.encoder = CLIPModel.from_pretrained(encoder_path)
        self.predictor = MLP()
        if predictor_path is None or not os.path.exists(predictor_path):
            download_url(URL, torch.hub.get_dir(), FILENAME, md5=MD5)
            predictor_path = os.path.join(torch.hub.get_dir(), FILENAME)
        state_dict = torch.load(predictor_path, map_location="cpu")
        self.predictor.load_state_dict(state_dict)
        # Inference-only module: start in eval mode (disables dropout).
        self.eval()


    def forward(self, pixel_values):
        """Return one aesthetic score per image (shape ``(batch,)``).

        ``pixel_values`` is expected to be CLIP-preprocessed image tensors
        of shape (batch, 3, H, W) — assumed 224x224; confirm against caller.
        """
        embed = self.encoder.get_image_features(pixel_values=pixel_values)
        # Score the unit-normalized embedding, as in the original predictor.
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)

        return self.predictor(embed).squeeze(1)
|
videox_fun/reward/reward_fn.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torchvision.transforms as transforms
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from torchvision.datasets.utils import download_url
|
| 8 |
+
from typing import Optional, Tuple
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# All reward models.
|
| 12 |
+
__all__ = ["AestheticReward", "HPSReward", "PickScoreReward", "MPSReward"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BaseReward(ABC):
    """Abstract base class for reward models.

    Subclasses set up their model and (optional) image transformations in
    ``__init__`` and implement ``__call__`` to score videos.
    """

    def __init__(self):
        """Define the reward model and optional image transformations here."""

    @abstractmethod
    def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score ``batch_frames`` of shape ``[B, C, T, H, W]`` (optionally
        conditioned on ``batch_prompt``) and return the mean-reduced
        ``(loss, reward)`` pair computed by the reward model.
        """
|
| 30 |
+
class AestheticReward(BaseReward):
    """Aesthetic Predictor [V2](https://github.com/christophschuhmann/improved-aesthetic-predictor)
    and [V2.5](https://github.com/discus0434/aesthetic-predictor-v2-5) reward model.

    Scores every sampled frame with the aesthetic predictor and converts the
    score into a loss; prompts are accepted but ignored (aesthetics is
    prompt-free).
    """
    def __init__(
        self,
        encoder_path="openai/clip-vit-large-patch14",
        predictor_path=None,
        version="v2",
        device="cpu",
        dtype=torch.float16,
        max_reward=10,
        loss_scale=0.1,
    ):
        from .improved_aesthetic_predictor import ImprovedAestheticPredictor
        # Bugfix: the v2.5 converter lives in this package's
        # `aesthetic_predictor_v2_5` subpackage. The previous import path
        # `..video_caption.utils.siglip_v2_5` does not exist in this
        # repository and raised ImportError on construction.
        from .aesthetic_predictor_v2_5 import convert_v2_5_from_siglip

        self.encoder_path = encoder_path
        self.predictor_path = predictor_path
        self.version = version
        self.device = device
        self.dtype = dtype
        self.max_reward = max_reward
        self.loss_scale = loss_scale

        if self.version not in ("v2", "v2.5"):
            raise ValueError("Only v2 and v2.5 are supported.")
        if self.version == "v2":
            assert "clip-vit-large-patch14" in encoder_path.lower()
            self.model = ImprovedAestheticPredictor(encoder_path=self.encoder_path, predictor_path=self.predictor_path)
            # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/preprocessor_config.json
            # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio.
            self.transform = transforms.Compose([
                transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
            ])
        else:  # v2.5
            assert "siglip-so400m-patch14-384" in encoder_path.lower()
            self.model, _ = convert_v2_5_from_siglip(encoder_model_name=self.encoder_path)
            # https://huggingface.co/google/siglip-so400m-patch14-384/blob/main/preprocessor_config.json
            self.transform = transforms.Compose([
                transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ])

        self.model.to(device=self.device, dtype=self.dtype)
        self.model.requires_grad_(False)

    def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the frame-averaged ``(loss, reward)`` for ``[B, C, T, H, W]`` frames."""
        # Iterate frame-wise: each slice holds one time step across the batch.
        batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
        batch_loss, batch_reward = 0, 0
        for frames in batch_frames:
            pixel_values = torch.stack([self.transform(frame) for frame in frames])
            pixel_values = pixel_values.to(self.device, dtype=self.dtype)
            if self.version == "v2":
                reward = self.model(pixel_values)
            else:  # v2.5: take the scalar logits from the classifier output
                reward = self.model(pixel_values).logits.squeeze()
            # Convert reward to loss: either maximize the raw reward, or pull
            # it toward `max_reward`.
            if self.max_reward is None:
                loss = (-1 * reward) * self.loss_scale
            else:
                loss = abs(reward - self.max_reward) * self.loss_scale
            batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()

        return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class HPSReward(BaseReward):
    """[HPS](https://github.com/tgxs002/HPSv2) v2 and v2.1 reward model.

    A ViT-H-14 open_clip model fine-tuned on human preferences; the reward
    is the image/text cosine similarity per frame.
    """
    def __init__(
        self,
        model_path=None,          # local checkpoint path; downloaded when missing
        version="v2.0",           # "v2.0" or "v2.1"
        device="cpu",
        dtype=torch.float16,
        max_reward=1,             # target reward for the loss conversion
        loss_scale=1,
    ):
        # Imported lazily so the hpsv2 dependency is only needed when this
        # reward is actually used.
        from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer

        self.model_path = model_path
        self.version = version
        self.device = device
        self.dtype = dtype
        self.max_reward = max_reward
        self.loss_scale = loss_scale

        self.model, _, _ = create_model_and_transforms(
            "ViT-H-14",
            "laion2B-s32B-b79K",
            precision=self.dtype,
            device=self.device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False,
        )
        self.tokenizer = get_tokenizer("ViT-H-14")

        # https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json
        # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
        ])

        if version == "v2.0":
            url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2_compressed.pt"
            filename = "HPS_v2_compressed.pt"
            md5 = "fd9180de357abf01fdb4eaad64631db4"
        elif version == "v2.1":
            url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2.1_compressed.pt"
            filename = "HPS_v2.1_compressed.pt"
            md5 = "4067542e34ba2553a738c5ac6c1d75c0"
        else:
            raise ValueError("Only v2.0 and v2.1 are supported.")
        # NOTE(review): when `self.model_path` is provided and exists, the
        # local variable `model_path` (the constructor argument) is used
        # below — verify this is the intended precedence.
        if self.model_path is None or not os.path.exists(self.model_path):
            download_url(url, torch.hub.get_dir(), md5=md5)
            model_path = os.path.join(torch.hub.get_dir(), filename)

        state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
        self.model.load_state_dict(state_dict)
        self.model.to(device=self.device, dtype=self.dtype)
        self.model.requires_grad_(False)
        self.model.eval()

    def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the frame-averaged ``(loss, reward)`` for frames vs prompts."""
        assert batch_frames.shape[0] == len(batch_prompt)
        # Compute batch reward and loss in frame-wise.
        batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
        batch_loss, batch_reward = 0, 0
        for frames in batch_frames:
            image_inputs = torch.stack([self.transform(frame) for frame in frames])
            image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
            text_inputs = self.tokenizer(batch_prompt).to(device=self.device)
            outputs = self.model(image_inputs, text_inputs)

            image_features, text_features = outputs["image_features"], outputs["text_features"]
            # Diagonal pairs frame i with prompt i (batch-aligned similarity).
            logits = image_features @ text_features.T
            reward = torch.diagonal(logits)
            # Convert reward to loss in [0, 1].
            if self.max_reward is None:
                loss = (-1 * reward) * self.loss_scale
            else:
                loss = abs(reward - self.max_reward) * self.loss_scale

            batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()

        return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class PickScoreReward(BaseReward):
    """[PickScore](https://github.com/yuvalkirstain/PickScore) reward model.

    A CLIP-H model fine-tuned on human preferences; the reward is the
    cosine similarity between each frame and its corresponding prompt.
    """
    def __init__(
        self,
        model_path="yuvalkirstain/PickScore_v1",
        device="cpu",
        dtype=torch.float16,
        max_reward=1,
        loss_scale=1,
    ):
        from transformers import AutoProcessor, AutoModel

        self.model_path = model_path
        self.device = device
        self.dtype = dtype
        self.max_reward = max_reward
        self.loss_scale = loss_scale

        # https://huggingface.co/yuvalkirstain/PickScore_v1/blob/main/preprocessor_config.json
        self.transform = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
        ])
        # The processor is only used for text tokenization here; image
        # preprocessing is done by `self.transform` above.
        self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=self.dtype)
        self.model = AutoModel.from_pretrained(model_path, torch_dtype=self.dtype).to(device)
        self.model.requires_grad_(False)
        self.model.eval()

    def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the frame-averaged ``(loss, reward)`` for frames vs prompts."""
        assert batch_frames.shape[0] == len(batch_prompt)
        # Compute batch reward and loss in frame-wise.
        batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")

        # Perf: the prompts do not change across frames — tokenize and embed
        # them once instead of recomputing inside the per-frame loop.
        text_inputs = self.processor(
            text=batch_prompt,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(self.device)
        text_features = self.model.get_text_features(**text_inputs)
        text_features = text_features / torch.norm(text_features, dim=-1, keepdim=True)

        batch_loss, batch_reward = 0, 0
        for frames in batch_frames:
            image_inputs = torch.stack([self.transform(frame) for frame in frames])
            image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
            image_features = self.model.get_image_features(pixel_values=image_inputs)
            image_features = image_features / torch.norm(image_features, dim=-1, keepdim=True)

            # Diagonal pairs frame i with prompt i (batch-aligned similarity).
            logits = image_features @ text_features.T
            reward = torch.diagonal(logits)
            # Convert reward to loss in [0, 1].
            if self.max_reward is None:
                loss = (-1 * reward) * self.loss_scale
            else:
                loss = abs(reward - self.max_reward) * self.loss_scale

            batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()

        return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class MPSReward(BaseReward):
    """[MPS](https://github.com/Kwai-Kolors/MPS) reward model.

    A condition-aware CLIP preference model: text, image, and a "condition"
    string (what aspects to judge) are embedded, and the reward is the
    image/text cosine similarity.
    """
    def __init__(
        self,
        model_path=None,          # local checkpoint path; downloaded when missing
        device="cpu",
        dtype=torch.float16,
        max_reward=1,             # target reward for the loss conversion
        loss_scale=1,
    ):
        from transformers import AutoTokenizer, AutoConfig
        from .MPS.trainer.models.clip_model import CLIPModel

        self.model_path = model_path
        self.device = device
        self.dtype = dtype
        # Default "overall" preference condition used when the caller does
        # not supply one.
        self.condition = "light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things."
        self.max_reward = max_reward
        self.loss_scale = loss_scale

        processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
        # https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json
        # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
        ])

        # We convert the original [ckpt](http://drive.google.com/file/d/17qrK_aJkVNM75ZEvMEePpLj6L867MLkN/view?usp=sharing)
        # (contains the entire model) to a `state_dict`.
        url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/MPS_overall.pth"
        filename = "MPS_overall.pth"
        md5 = "1491cbbbd20565747fe07e7572e2ac56"
        # NOTE(review): when `self.model_path` is provided and exists, the
        # local variable `model_path` (the constructor argument) is used
        # below — verify this is the intended precedence.
        if self.model_path is None or not os.path.exists(self.model_path):
            download_url(url, torch.hub.get_dir(), md5=md5)
            model_path = os.path.join(torch.hub.get_dir(), filename)

        self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
        config = AutoConfig.from_pretrained(processor_name_or_path)
        self.model = CLIPModel(config)
        state_dict = torch.load(model_path, map_location="cpu")
        # strict=False: the converted checkpoint may not cover every key of
        # the freshly-built CLIPModel.
        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(device=self.device, dtype=self.dtype)
        self.model.requires_grad_(False)
        self.model.eval()

    def _tokenize(self, caption):
        """Tokenize `caption` (str or list of str) to padded input ids."""
        input_ids = self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids

        return input_ids

    def __call__(
        self,
        batch_frames: torch.Tensor,
        batch_prompt: list[str],
        batch_condition: Optional[list[str]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the frame-averaged ``(loss, reward)``; ``batch_condition``
        defaults to the overall-preference condition string.
        """
        if batch_condition is None:
            batch_condition = [self.condition] * len(batch_prompt)
        # Iterate frame-wise: each slice holds one time step across the batch.
        batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
        batch_loss, batch_reward = 0, 0
        for frames in batch_frames:
            image_inputs = torch.stack([self.transform(frame) for frame in frames])
            image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
            text_inputs = self._tokenize(batch_prompt).to(self.device)
            condition_inputs = self._tokenize(batch_condition).to(device=self.device)
            text_features, image_features = self.model(text_inputs, image_inputs, condition_inputs)

            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # reward = self.model.logit_scale.exp() * torch.diag(torch.einsum('bd,cd->bc', text_features, image_features))
            # Diagonal pairs frame i with prompt i (batch-aligned similarity).
            logits = image_features @ text_features.T
            reward = torch.diagonal(logits)
            # Convert reward to loss in [0, 1].
            if self.max_reward is None:
                loss = (-1 * reward) * self.loss_scale
            else:
                loss = abs(reward - self.max_reward) * self.loss_scale

            batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()

        return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
if __name__ == "__main__":
    # Smoke test: sample frames from two local videos and print the
    # (loss, reward) pair from every reward model. Requires CUDA, the
    # optional dependencies (decord, hpsv2, transformers), and network
    # access to download the reward checkpoints.
    import numpy as np
    from decord import VideoReader

    video_path_list = ["your_video_path_1.mp4", "your_video_path_2.mp4"]
    prompt_list = ["your_prompt_1", "your_prompt_2"]
    num_sampled_frames = 8

    to_tensor = transforms.ToTensor()

    sampled_frames_list = []
    for video_path in video_path_list:
        vr = VideoReader(video_path)
        # Evenly spaced frame indices over the whole clip.
        sampled_frame_indices = np.linspace(0, len(vr), num_sampled_frames, endpoint=False, dtype=int)
        sampled_frames = vr.get_batch(sampled_frame_indices).asnumpy()
        sampled_frames = torch.stack([to_tensor(frame) for frame in sampled_frames])
        sampled_frames_list.append(sampled_frames)
    sampled_frames = torch.stack(sampled_frames_list)
    # Reward models expect [B, C, T, H, W].
    sampled_frames = rearrange(sampled_frames, "b t c h w -> b c t h w")

    aesthetic_reward_v2 = AestheticReward(device="cuda", dtype=torch.bfloat16)
    print(f"aesthetic_reward_v2: {aesthetic_reward_v2(sampled_frames)}")

    aesthetic_reward_v2_5 = AestheticReward(
        encoder_path="google/siglip-so400m-patch14-384", version="v2.5", device="cuda", dtype=torch.bfloat16
    )
    print(f"aesthetic_reward_v2_5: {aesthetic_reward_v2_5(sampled_frames)}")

    hps_reward_v2 = HPSReward(device="cuda", dtype=torch.bfloat16)
    print(f"hps_reward_v2: {hps_reward_v2(sampled_frames, prompt_list)}")

    hps_reward_v2_1 = HPSReward(version="v2.1", device="cuda", dtype=torch.bfloat16)
    print(f"hps_reward_v2_1: {hps_reward_v2_1(sampled_frames, prompt_list)}")

    pick_score = PickScoreReward(device="cuda", dtype=torch.bfloat16)
    print(f"pick_score_reward: {pick_score(sampled_frames, prompt_list)}")

    mps_score = MPSReward(device="cuda", dtype=torch.bfloat16)
    print(f"mps_reward: {mps_score(sampled_frames, prompt_list)}")
|
videox_fun/ui/cogvideox_fun_ui.py
ADDED
|
@@ -0,0 +1,667 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
|
| 2 |
+
"""
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from safetensors import safe_open
|
| 12 |
+
|
| 13 |
+
from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
|
| 14 |
+
from ..models import (AutoencoderKLCogVideoX, CogVideoXTransformer3DModel,
|
| 15 |
+
T5EncoderModel, T5Tokenizer)
|
| 16 |
+
from ..pipeline import (CogVideoXFunControlPipeline,
|
| 17 |
+
CogVideoXFunInpaintPipeline, CogVideoXFunPipeline)
|
| 18 |
+
from ..utils.fp8_optimization import convert_weight_dtype_wrapper
|
| 19 |
+
from ..utils.lora_utils import merge_lora, unmerge_lora
|
| 20 |
+
from ..utils.utils import (get_image_to_video_latent,
|
| 21 |
+
get_video_to_video_latent, save_videos_grid)
|
| 22 |
+
from .controller import (Fun_Controller, Fun_Controller_Client,
|
| 23 |
+
all_cheduler_dict, css, ddpm_scheduler_dict,
|
| 24 |
+
flow_scheduler_dict, gradio_version,
|
| 25 |
+
gradio_version_is_above_4)
|
| 26 |
+
from .ui import (create_cfg_and_seedbox,
|
| 27 |
+
create_fake_finetune_models_checkpoints,
|
| 28 |
+
create_fake_height_width, create_fake_model_checkpoints,
|
| 29 |
+
create_fake_model_type, create_finetune_models_checkpoints,
|
| 30 |
+
create_generation_method,
|
| 31 |
+
create_generation_methods_and_video_length,
|
| 32 |
+
create_height_width, create_model_checkpoints,
|
| 33 |
+
create_model_type, create_prompts, create_samplers,
|
| 34 |
+
create_ui_outputs)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class CogVideoXFunController(Fun_Controller):
    """Gradio controller for CogVideoX-Fun.

    Extends ``Fun_Controller`` with CogVideoX-specific model loading
    (:meth:`update_diffusion_transformer`) and generation
    (:meth:`generate`).  Model/pipeline state is kept on ``self`` so the
    same controller instance serves all UI callbacks.
    """

    def update_diffusion_transformer(self, diffusion_transformer_dropdown):
        """Load the VAE/transformer/text-encoder for the selected checkpoint and rebuild the pipeline.

        Args:
            diffusion_transformer_dropdown: Path or repo id of the selected
                diffusion transformer checkpoint; ``"none"`` is a no-op.

        Returns:
            ``gr.update()`` so the dropdown callback leaves the UI unchanged.
        """
        print("Update diffusion transformer")
        self.diffusion_transformer_dropdown = diffusion_transformer_dropdown
        if diffusion_transformer_dropdown == "none":
            return gr.update()
        self.vae = AutoencoderKLCogVideoX.from_pretrained(
            diffusion_transformer_dropdown,
            subfolder="vae",
        ).to(self.weight_dtype)

        # Get Transformer
        self.transformer = CogVideoXTransformer3DModel.from_pretrained(
            diffusion_transformer_dropdown,
            subfolder="transformer",
            low_cpu_mem_usage=True,
        ).to(self.weight_dtype)

        # Get tokenizer and text_encoder
        tokenizer = T5Tokenizer.from_pretrained(
            diffusion_transformer_dropdown, subfolder="tokenizer"
        )
        text_encoder = T5EncoderModel.from_pretrained(
            diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
        )

        # Get pipeline
        if self.model_type == "Inpaint":
            # A transformer with more input channels than the VAE latent
            # channels indicates an inpaint-style checkpoint (extra mask /
            # masked-video channels) — pick the inpaint pipeline for it.
            if self.transformer.config.in_channels != self.vae.config.latent_channels:
                self.pipeline = CogVideoXFunInpaintPipeline(
                    tokenizer=tokenizer,
                    text_encoder=text_encoder,
                    vae=self.vae,
                    transformer=self.transformer,
                    # Use the first scheduler in the dict as the default.
                    scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
                )
            else:
                self.pipeline = CogVideoXFunPipeline(
                    tokenizer=tokenizer,
                    text_encoder=text_encoder,
                    vae=self.vae,
                    transformer=self.transformer,
                    scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
                )
        else:
            # "Control" model type: pipeline that conditions on a control video.
            self.pipeline = CogVideoXFunControlPipeline(
                diffusion_transformer_dropdown,
                vae=self.vae,
                transformer=self.transformer,
                scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
                torch_dtype=self.weight_dtype
            )

        # Sequence-parallel inference (ulysses / ring attention) across GPUs.
        if self.ulysses_degree > 1 or self.ring_degree > 1:
            self.transformer.enable_multi_gpus_inference()

        # GPU memory strategy: most aggressive offload first.
        if self.GPU_memory_mode == "sequential_cpu_offload":
            self.pipeline.enable_sequential_cpu_offload(device=self.device)
        elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
            # Wrap transformer weights for fp8 storage before offloading.
            convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype)
            self.pipeline.enable_model_cpu_offload(device=self.device)
        elif self.GPU_memory_mode == "model_cpu_offload":
            self.pipeline.enable_model_cpu_offload(device=self.device)
        else:
            self.pipeline.to(self.device)
        print("Update diffusion transformer done")
        return gr.update()

    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        overlap_video_length,
        partial_video_length,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        control_video,
        denoise_strength,
        seed_textbox,
        is_api = False,
    ):
        """Run one generation request from the UI (or the HTTP API).

        Dispatches on ``self.model_type`` and ``generation_method`` to
        text/image/video-to-video, long-video, or control-video generation.

        Returns:
            When ``is_api`` is True: ``(save_path_or_empty, message)``.
            Otherwise: ``(image_update, video_update, message)`` Gradio updates.
        """
        self.clear_cache()

        self.input_check(
            resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api
        )
        is_image = True if generation_method == "Image Generation" else False

        # Reload base / LoRA weights only when the selection changed.
        if self.base_model_path != base_model_dropdown:
            self.update_base_model(base_model_dropdown)

        if self.lora_model_path != lora_model_dropdown:
            self.update_lora_model(lora_model_dropdown)

        # Swap in the sampler chosen in the UI, keeping the current config.
        self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)

        if resize_method == "Resize according to Reference":
            height_slider, width_slider = self.get_height_width_from_reference(
                base_resolution, start_image, validation_video, control_video,
            )
        if self.lora_model_path != "none":
            # lora part
            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

        # -1 or empty seed means "random"; the drawn seed is reused for the generator.
        if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
        else: seed_textbox = np.random.randint(0, 1e10)
        generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))

        try:
            if self.model_type == "Inpaint":
                if self.transformer.config.in_channels != self.vae.config.latent_channels:
                    if generation_method == "Long Video Generation":
                        if validation_video is not None:
                            raise gr.Error(f"Video to Video is not Support Long Video Generation now.")
                        # Generate the long video in overlapping chunks and
                        # cross-fade consecutive chunks over `overlap_video_length` frames.
                        init_frames = 0
                        last_frames = init_frames + partial_video_length
                        while init_frames < length_slider:
                            if last_frames >= length_slider:
                                # Final chunk: shrink to the remaining frames,
                                # snapped to the VAE temporal compression grid.
                                _partial_video_length = length_slider - init_frames
                                _partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1

                                if _partial_video_length <= 0:
                                    break
                            else:
                                _partial_video_length = partial_video_length

                            # Only the final chunk is conditioned on the end image.
                            if last_frames >= length_slider:
                                input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
                            else:
                                input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))

                            with torch.no_grad():
                                sample = self.pipeline(
                                    prompt_textbox,
                                    negative_prompt = negative_prompt_textbox,
                                    num_inference_steps = sample_step_slider,
                                    guidance_scale = cfg_scale_slider,
                                    width = width_slider,
                                    height = height_slider,
                                    num_frames = _partial_video_length,
                                    generator = generator,

                                    video = input_video,
                                    mask_video = input_video_mask,
                                    strength = 1,
                                ).videos

                            if init_frames != 0:
                                # Linear blend weights 0..(n-1)/n over the overlap window,
                                # broadcast over batch/channel/H/W.
                                mix_ratio = torch.from_numpy(
                                    np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
                                ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

                                new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
                                    sample[:, :, :overlap_video_length] * mix_ratio
                                new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)

                                sample = new_sample
                            else:
                                new_sample = sample

                            if last_frames >= length_slider:
                                break

                            # Seed the next chunk with the last `overlap_video_length`
                            # frames of this one, converted back to PIL images.
                            start_image = [
                                Image.fromarray(
                                    (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
                                ) for _index in range(-overlap_video_length, 0)
                            ]

                            init_frames = init_frames + _partial_video_length - overlap_video_length
                            last_frames = init_frames + _partial_video_length
                    else:
                        # Single-shot inpaint generation: video-to-video if a
                        # validation video was supplied, otherwise image-to-video.
                        if validation_video is not None:
                            input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=8)
                            strength = denoise_strength
                        else:
                            input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
                            strength = 1

                        sample = self.pipeline(
                            prompt_textbox,
                            negative_prompt = negative_prompt_textbox,
                            num_inference_steps = sample_step_slider,
                            guidance_scale = cfg_scale_slider,
                            width = width_slider,
                            height = height_slider,
                            num_frames = length_slider if not is_image else 1,
                            generator = generator,

                            video = input_video,
                            mask_video = input_video_mask,
                            strength = strength,
                        ).videos
                else:
                    # Plain text-to-video checkpoint (no mask channels).
                    sample = self.pipeline(
                        prompt_textbox,
                        negative_prompt = negative_prompt_textbox,
                        num_inference_steps = sample_step_slider,
                        guidance_scale = cfg_scale_slider,
                        width = width_slider,
                        height = height_slider,
                        num_frames = length_slider if not is_image else 1,
                        generator = generator
                    ).videos
            else:
                # Control model: condition on the uploaded control video.
                input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=8)

                sample = self.pipeline(
                    prompt_textbox,
                    negative_prompt = negative_prompt_textbox,
                    num_inference_steps = sample_step_slider,
                    guidance_scale = cfg_scale_slider,
                    width = width_slider,
                    height = height_slider,
                    num_frames = length_slider if not is_image else 1,
                    generator = generator,

                    control_video = input_video,
                ).videos
        except Exception as e:
            # Best-effort cleanup on failure: free cache and undo the LoRA
            # merge so the pipeline is reusable, then report the error.
            self.clear_cache()
            if self.lora_model_path != "none":
                self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
            if is_api:
                return "", f"Error. error information is {str(e)}"
            else:
                return gr.update(), gr.update(), f"Error. error information is {str(e)}"

        self.clear_cache()
        # lora part
        if self.lora_model_path != "none":
            self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

        save_sample_path = self.save_outputs(
            is_image, length_slider, sample, fps=8
        )

        # Route the result to the image or the video output component.
        if is_image or length_slider == 1:
            if is_api:
                return save_sample_path, "Success"
            else:
                if gradio_version_is_above_4:
                    return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
                else:
                    # Gradio 3.x uses the .update classmethods instead of constructors.
                    return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
        else:
            if is_api:
                return save_sample_path, "Success"
            else:
                if gradio_version_is_above_4:
                    return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
                else:
                    return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
| 305 |
+
|
| 306 |
+
# Host mode uses the full local controller; client mode delegates to a remote
# serving endpoint via the generic Fun_Controller_Client.
CogVideoXFunController_Host = CogVideoXFunController
CogVideoXFunController_Client = Fun_Controller_Client
| 309 |
+
def ui(GPU_memory_mode, scheduler_dict, ulysses_degree, ring_degree, weight_dtype, savedir_sample=None):
    """Build the full (local-model) CogVideoX-Fun Gradio demo.

    Args:
        GPU_memory_mode: Offload strategy passed to the controller.
        scheduler_dict: Mapping of sampler name -> scheduler class.
        ulysses_degree: Sequence-parallel degree (ulysses attention).
        ring_degree: Sequence-parallel degree (ring attention).
        weight_dtype: torch dtype used for model weights.
        savedir_sample: Optional directory for saved outputs.

    Returns:
        Tuple ``(demo, controller)`` — the Gradio Blocks app and its controller.
    """
    controller = CogVideoXFunController(
        GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
        ulysses_degree=ulysses_degree, ring_degree=ring_degree,
        config_path=None, enable_teacache=None, teacache_threshold=None, weight_dtype=weight_dtype,
        savedir_sample=savedir_sample,
    )

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # CogVideoX-Fun:

            A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.

            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
            """
        )
        # Model selection panel: model type, transformer checkpoint, base/LoRA weights.
        with gr.Column(variant="panel"):
            model_type = create_model_type(visible=True)
            diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \
                create_model_checkpoints(controller, visible=True)
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \
                create_finetune_models_checkpoints(controller, visible=True)

        # Generation panel: prompts, sampler, resolution, length, sources, outputs.
        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts()

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)

                    resize_method, width_slider, height_slider, base_resolution = create_height_width(
                        default_height = 384, default_width = 672, maximum_height = 1344,
                        maximum_width = 1344,
                    )
                    gr.Markdown(
                        """
                        V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
                        (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
                        """
                    )
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation", "Image Generation", "Long Video Generation"],
                            default_video_length=49,
                            maximum_video_length=85,
                        )
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox
                    )
                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                # Output column is created inside this helper.
                result_image, result_video, infer_progress = create_ui_outputs()

        # Reload the pipeline when the model type (Inpaint/Control) changes.
        model_type.change(
            fn=controller.update_model_type,
            inputs=[model_type],
            outputs=[]
        )

        def upload_generation_method(generation_method):
            # Toggle length slider and long-video controls per generation mode.
            if generation_method == "Video Generation":
                return [gr.update(visible=True, maximum=85, value=49, interactive=True), gr.update(visible=False), gr.update(visible=False)]
            elif generation_method == "Image Generation":
                return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)]
            else:
                return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
        generation_method.change(
            upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
        )

        def upload_source_method(source_method):
            # Show exactly one source column and clear inputs of the other modes.
            if source_method == "Text to Video (文本到视频)":
                return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
            elif source_method == "Image to Video (图片到视频)":
                return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
            elif source_method == "Video to Video (视频到视频)":
                return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
            else:
                return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
        source_method.change(
            upload_source_method, source_method, [
                image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
                validation_video, validation_video_mask, control_video
            ]
        )

        def upload_resize_method(resize_method):
            # "Generate by": explicit W/H sliders; otherwise: base-resolution picker.
            if resize_method == "Generate by":
                return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
            else:
                return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
        resize_method.change(
            upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
        )

        # NOTE: input order must match CogVideoXFunController.generate's signature.
        generate_button.click(
            fn=controller.generate,
            inputs=[
                diffusion_transformer_dropdown,
                base_model_dropdown,
                lora_model_dropdown,
                lora_alpha_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                resize_method,
                width_slider,
                height_slider,
                base_resolution,
                generation_method,
                length_slider,
                overlap_video_length,
                partial_video_length,
                cfg_scale_slider,
                start_image,
                end_image,
                validation_video,
                validation_video_mask,
                control_video,
                denoise_strength,
                seed_textbox,
            ],
            outputs=[result_image, result_video, infer_progress]
        )
    return demo, controller
+
|
| 440 |
+
def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, ulysses_degree, ring_degree, weight_dtype, savedir_sample=None):
    """Build the host-mode CogVideoX-Fun demo with a fixed, preloaded model.

    Unlike :func:`ui`, model-selection widgets are non-interactive "fake"
    components and the checkpoint/type are fixed at startup.

    Args:
        GPU_memory_mode: Offload strategy passed to the controller.
        scheduler_dict: Mapping of sampler name -> scheduler class.
        model_name: Checkpoint preloaded by the host.
        model_type: "Inpaint" or "Control".
        ulysses_degree: Sequence-parallel degree (ulysses attention).
        ring_degree: Sequence-parallel degree (ring attention).
        weight_dtype: torch dtype used for model weights.
        savedir_sample: Optional directory for saved outputs.

    Returns:
        Tuple ``(demo, controller)``.
    """
    controller = CogVideoXFunController_Host(
        GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type,
        ulysses_degree=ulysses_degree, ring_degree=ring_degree,
        config_path=None, enable_teacache=None, teacache_threshold=None, weight_dtype=weight_dtype,
        savedir_sample=savedir_sample,
    )

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # CogVideoX-Fun

            A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.

            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
            """
        )
        # Model panel: read-only placeholders since the model is fixed in host mode.
        with gr.Column(variant="panel"):
            model_type = create_fake_model_type(visible=True)
            diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)

        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts()

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)

                    resize_method, width_slider, height_slider, base_resolution = create_height_width(
                        default_height = 384, default_width = 672, maximum_height = 1344,
                        maximum_width = 1344,
                    )
                    gr.Markdown(
                        """
                        V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
                        (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
                        """
                    )
                    # Host mode: no "Long Video Generation" option.
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation", "Image Generation"],
                            default_video_length=49,
                            maximum_video_length=85,
                        )
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox
                    )
                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                # Output column is created inside this helper.
                result_image, result_video, infer_progress = create_ui_outputs()

        def upload_generation_method(generation_method):
            # Only the length slider is toggled here (no long-video controls).
            if generation_method == "Video Generation":
                return gr.update(visible=True, minimum=8, maximum=85, value=49, interactive=True)
            elif generation_method == "Image Generation":
                return gr.update(minimum=1, maximum=1, value=1, interactive=False)
        generation_method.change(
            upload_generation_method, generation_method, [length_slider]
        )

        def upload_source_method(source_method):
            # Show exactly one source column and clear inputs of the other modes.
            if source_method == "Text to Video (文本到视频)":
                return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
            elif source_method == "Image to Video (图片到视频)":
                return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
            elif source_method == "Video to Video (视频到视频)":
                return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
            else:
                return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
        source_method.change(
            upload_source_method, source_method, [
                image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
                validation_video, validation_video_mask, control_video
            ]
        )

        def upload_resize_method(resize_method):
            # "Generate by": explicit W/H sliders; otherwise: base-resolution picker.
            if resize_method == "Generate by":
                return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
            else:
                return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
        resize_method.change(
            upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
        )

        # NOTE: input order must match CogVideoXFunController.generate's signature.
        generate_button.click(
            fn=controller.generate,
            inputs=[
                diffusion_transformer_dropdown,
                base_model_dropdown,
                lora_model_dropdown,
                lora_alpha_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                resize_method,
                width_slider,
                height_slider,
                base_resolution,
                generation_method,
                length_slider,
                overlap_video_length,
                partial_video_length,
                cfg_scale_slider,
                start_image,
                end_image,
                validation_video,
                validation_video_mask,
                control_video,
                denoise_strength,
                seed_textbox,
            ],
            outputs=[result_image, result_video, infer_progress]
        )
    return demo, controller
+
|
| 561 |
+
def ui_client(scheduler_dict, model_name, savedir_sample=None):
    """Build the hosted-demo (client) Gradio UI for CogVideoX-Fun.

    Unlike the full UI, model/path widgets are fixed ("fake") and there is no
    control-video tab in the source radio, so the generate inputs omit
    overlap/partial lengths and the control video.

    Args:
        scheduler_dict: mapping of scheduler display name -> scheduler class.
        model_name: the single pretrained model path shown (non-editable).
        savedir_sample: optional directory where generated samples are saved.

    Returns:
        (demo, controller): the gr.Blocks app and its backing controller.
    """
    controller = CogVideoXFunController_Client(scheduler_dict, savedir_sample)

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # CogVideoX-Fun

            A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.

            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
            """
        )
        with gr.Column(variant="panel"):
            # Fixed model widgets: the hosted demo cannot switch checkpoints.
            diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)

        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts()

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50)

                    # Resolution widgets are non-interactive in the client UI.
                    resize_method, width_slider, height_slider, base_resolution = create_fake_height_width(
                        default_height = 384, default_width = 672, maximum_height = 1344,
                        maximum_width = 1344,
                    )
                    gr.Markdown(
                        """
                        V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
                        (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
                        """
                    )
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation", "Image Generation"],
                            default_video_length=49,
                            maximum_video_length=85,
                        )
                    # Only three source modes here (no "Video Control" tab);
                    # control_video_col/control_video are created but stay hidden.
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"], prompt_textbox
                    )

                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

            def upload_generation_method(generation_method):
                # Image generation pins the frame count to 1; video generation
                # re-enables the 5..85 range.
                # NOTE(review): falls through (returns None) for any other
                # radio value — fine while only these two options exist.
                if generation_method == "Video Generation":
                    return gr.update(visible=True, minimum=5, maximum=85, value=49, interactive=True)
                elif generation_method == "Image Generation":
                    return gr.update(minimum=1, maximum=1, value=1, interactive=False)
            generation_method.change(
                upload_generation_method, generation_method, [length_slider]
            )

            def upload_source_method(source_method):
                # Toggle the i2v / v2v columns and clear inputs that belong to
                # the modes being hidden. Update order MUST match the outputs
                # list of source_method.change below:
                # [image_to_video_col, video_to_video_col, start_image,
                #  end_image, validation_video, validation_video_mask]
                if source_method == "Text to Video (文本到视频)":
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Image to Video (图片到视频)":
                    return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
                else:
                    return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
            source_method.change(
                upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
            )

            def upload_resize_method(resize_method):
                # "Generate by" shows explicit W/H sliders; the alternative
                # shows the base-resolution radio instead.
                if resize_method == "Generate by":
                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
            resize_method.change(
                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
            )

            # Client variant: no overlap_video_length / partial_video_length /
            # control_video in the generate signature.
            generate_button.click(
                fn=controller.generate,
                inputs=[
                    diffusion_transformer_dropdown,
                    base_model_dropdown,
                    lora_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    resize_method,
                    width_slider,
                    height_slider,
                    base_resolution,
                    generation_method,
                    length_slider,
                    cfg_scale_slider,
                    start_image,
                    end_image,
                    validation_video,
                    validation_video_mask,
                    denoise_strength,
                    seed_textbox,
                ],
                outputs=[result_image, result_video, infer_progress]
            )
    return demo, controller
|
videox_fun/ui/ui.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def create_model_type(visible):
    """Render the interactive model-type selector (Inpaint vs Control).

    Returns the dropdown so the caller can wire change handlers to it.
    """
    gr.Markdown(
        """
        ### Model Type (模型的种类,正常模型还是控制模型).
        """,
        visible=visible,
    )
    with gr.Row():
        selector = gr.Dropdown(
            choices=["Inpaint", "Control"],
            value="Inpaint",
            label="The model type of the model (模型的种类,正常模型还是控制模型)",
            interactive=True,
            visible=visible,
        )
    return selector
|
| 22 |
+
|
| 23 |
+
def create_fake_model_type(visible):
    """Render a read-only model-type selector for the hosted-demo UI."""
    gr.Markdown(
        """
        ### Model Type (模型的种类,正常模型还是控制模型).
        """,
        visible=visible,
    )
    with gr.Row():
        selector = gr.Dropdown(
            choices=["Inpaint", "Control"],
            value="Inpaint",
            label="The model type of the model (模型的种类,正常模型还是控制模型)",
            interactive=False,
            visible=visible,
        )
    return selector
|
| 39 |
+
|
| 40 |
+
def create_model_checkpoints(controller, visible):
    """Build the pretrained-checkpoint dropdown plus its refresh button.

    Selecting an entry calls ``controller.update_diffusion_transformer``;
    the refresh button re-scans the checkpoint list on the controller.
    """
    gr.Markdown(
        """
        ### Model checkpoints (模型路径).
        """
    )
    with gr.Row(visible=visible):
        checkpoint_dropdown = gr.Dropdown(
            choices=controller.diffusion_transformer_list,
            value="none",
            label="Pretrained Model Path (预训练模型路径)",
            interactive=True,
        )
        checkpoint_dropdown.change(
            fn=controller.update_diffusion_transformer,
            inputs=[checkpoint_dropdown],
            outputs=[checkpoint_dropdown],
        )

        refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

        def _rescan_checkpoints():
            controller.refresh_diffusion_transformer()
            return gr.update(choices=controller.diffusion_transformer_list)

        refresh_button.click(fn=_rescan_checkpoints, inputs=[], outputs=[checkpoint_dropdown])

    return checkpoint_dropdown, refresh_button
|
| 66 |
+
|
| 67 |
+
def create_fake_model_checkpoints(model_name, visible):
    """Show a single fixed, non-editable pretrained model path."""
    gr.Markdown(
        """
        ### Model checkpoints (模型路径).
        """
    )
    with gr.Row(visible=visible):
        checkpoint_dropdown = gr.Dropdown(
            choices=[model_name],
            value=model_name,
            label="Pretrained Model Path (预训练模型路径)",
            interactive=False,
        )
    return checkpoint_dropdown
|
| 81 |
+
|
| 82 |
+
def create_finetune_models_checkpoints(controller, visible):
    """Base-model / LoRA dropdowns, LoRA weight slider, and a refresh button."""
    with gr.Row(visible=visible):
        base_model_dropdown = gr.Dropdown(
            choices=controller.personalized_model_list,
            value="none",
            label="Select base Dreambooth model (选择基模型[非必需])",
            interactive=True,
        )

        lora_model_dropdown = gr.Dropdown(
            choices=["none"] + controller.personalized_model_list,
            value="none",
            label="Select LoRA model (选择LoRA模型[非必需])",
            interactive=True,
        )

        lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)

        refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

        def _rescan_personalized():
            # One rescan feeds both dropdowns; order matches the outputs list.
            controller.refresh_personalized_model()
            return [
                gr.update(choices=controller.personalized_model_list),
                gr.update(choices=["none"] + controller.personalized_model_list),
            ]

        refresh_button.click(fn=_rescan_personalized, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])

    return base_model_dropdown, lora_model_dropdown, lora_alpha_slider, refresh_button
|
| 110 |
+
|
| 111 |
+
def create_fake_finetune_models_checkpoints(visible):
    # Hosted-demo stand-in for create_finetune_models_checkpoints: the base
    # model is locked to "none" and the LoRA column is hidden entirely.
    # NOTE(review): the `visible` parameter is accepted for signature parity
    # but never used — all visibility below is hard-coded. Confirm intended.
    with gr.Row():
        base_model_dropdown = gr.Dropdown(
            label="Select base Dreambooth model (选择基模型[非必需])",
            choices=["none"],
            value="none",
            interactive=False,
            visible=False
        )
        # Entire LoRA block is hidden in the demo; components are still
        # created so callers get the same return tuple as the real builder.
        with gr.Column(visible=False):
            gr.Markdown(
                """
                ### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora).
                """
            )
            with gr.Row():
                lora_model_dropdown = gr.Dropdown(
                    label="Select LoRA model",
                    choices=["none"],
                    value="none",
                    interactive=True,
                )

                lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)

    return base_model_dropdown, lora_model_dropdown, lora_alpha_slider
|
| 137 |
+
|
| 138 |
+
def create_prompts(
    prompt="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
    negative_prompt="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. "
):
    """Render positive/negative prompt textboxes pre-filled with defaults."""
    gr.Markdown(
        """
        ### Configs for Generation (生成参数配置).
        """
    )

    positive_box = gr.Textbox(label="Prompt (正向提示词)", lines=2, value=prompt)
    negative_box = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value=negative_prompt)
    return positive_box, negative_box
|
| 151 |
+
|
| 152 |
+
def create_samplers(controller, maximum_step=100):
    """Scheduler picker and step-count slider; defaults to the first scheduler."""
    scheduler_names = list(controller.scheduler_dict.keys())
    with gr.Row():
        sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=scheduler_names, value=scheduler_names[0])
        sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=maximum_step, step=1)

    return sampler_dropdown, sample_step_slider
|
| 158 |
+
|
| 159 |
+
def create_height_width(default_height, default_width, maximum_height, maximum_width):
    """Resolution controls: mode radio, W/H sliders, hidden base-resolution radio."""
    resize_method = gr.Radio(
        ["Generate by", "Resize according to Reference"],
        show_label=False,
        value="Generate by",
    )
    width_slider = gr.Slider(
        label="Width (视频宽度)", minimum=128, maximum=maximum_width,
        step=16, value=default_width,
    )
    height_slider = gr.Slider(
        label="Height (视频高度)", minimum=128, maximum=maximum_height,
        step=16, value=default_height,
    )
    base_resolution = gr.Radio(
        label="Base Resolution of Pretrained Models", choices=[512, 768, 960],
        value=512, visible=False,
    )

    return resize_method, width_slider, height_slider, base_resolution
|
| 170 |
+
|
| 171 |
+
def create_fake_height_width(default_height, default_width, maximum_height, maximum_width):
    """Non-interactive resolution controls for the hosted-demo UI."""
    resize_method = gr.Radio(
        ["Generate by", "Resize according to Reference"],
        show_label=False,
        value="Generate by",
    )
    width_slider = gr.Slider(
        label="Width (视频宽度)", minimum=128, maximum=maximum_width,
        step=16, value=default_width, interactive=False,
    )
    height_slider = gr.Slider(
        label="Height (视频高度)", minimum=128, maximum=maximum_height,
        step=16, value=default_height, interactive=False,
    )
    base_resolution = gr.Radio(
        label="Base Resolution of Pretrained Models", choices=[512, 768, 960],
        value=512, interactive=False, visible=False,
    )

    return resize_method, width_slider, height_slider, base_resolution
|
| 182 |
+
|
| 183 |
+
def create_generation_methods_and_video_length(
    generation_method_options,
    default_video_length,
    maximum_video_length
):
    """Generation-mode radio plus frame-count sliders.

    The overlap and partial-length sliders are created hidden; callers
    reveal them for long-video modes.
    """
    with gr.Group():
        generation_method = gr.Radio(
            generation_method_options,
            show_label=False,
            value="Video Generation",
        )
        with gr.Row():
            length_slider = gr.Slider(
                label="Animation length (视频帧数)", minimum=1,
                maximum=maximum_video_length, step=4, value=default_video_length,
            )
            overlap_video_length = gr.Slider(
                label="Overlap length (视频续写的重叠帧数)", minimum=1,
                maximum=4, step=1, value=4, visible=False,
            )
            partial_video_length = gr.Slider(
                label="Partial video generation length (每个部分的视频生成帧数)", minimum=5,
                maximum=maximum_video_length, step=4, value=25, visible=False,
            )

    return generation_method, length_slider, overlap_video_length, partial_video_length
|
| 200 |
+
|
| 201 |
+
def create_generation_method(source_method_options, prompt_textbox, support_end_image=True):
    # Builds the input-source section: a radio selecting text/image/video
    # conditioning plus one initially-hidden column per mode. Every component
    # is returned so the caller can wire visibility toggles and feed them to
    # the generation callback.
    source_method = gr.Radio(
        source_method_options,
        value="Text to Video (文本到视频)",
        show_label=False,
    )
    # --- Image-to-video inputs (revealed by the matching radio choice) ---
    with gr.Column(visible = False) as image_to_video_col:
        start_image = gr.Image(
            label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True,
            elem_id="i2v_start", sources="upload", type="filepath",
        )

        template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
        def select_template(evt: gr.SelectData):
            # Clicking a thumbnail fills both the start image and the prompt
            # textbox with the curated prompt matching that asset.
            text = {
                "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.",
                "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.",
                "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.",
                "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.",
                "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.",
            }[template_gallery_path[evt.index]]
            return template_gallery_path[evt.index], text

        template_gallery = gr.Gallery(
            template_gallery_path,
            columns=5, rows=1,
            height=140,
            allow_preview=False,
            container=False,
            label="Template Examples",
        )
        template_gallery.select(select_template, None, [start_image, prompt_textbox])

        # End image is optional; hidden entirely for models without
        # end-frame support.
        with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False, visible=support_end_image):
            end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")

    # --- Video-to-video inputs ---
    with gr.Column(visible = False) as video_to_video_col:
        with gr.Row():
            validation_video = gr.Video(
                label="The video to convert (视频转视频的参考视频)", show_label=True,
                elem_id="v2v", sources="upload",
            )
        with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
            gr.Markdown(
                """
                - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
                (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
                """
            )
            validation_video_mask = gr.Image(
                label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
                show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
            )
        denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)

    # --- Control-video inputs (used only by control-model variants) ---
    with gr.Column(visible = False) as control_video_col:
        gr.Markdown(
            """
            Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
            """
        )
        control_video = gr.Video(
            label="The control video (用于提供控制信号的video)", show_label=True,
            elem_id="v2v_control", sources="upload",
        )
    return image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video
|
| 267 |
+
|
| 268 |
+
def create_cfg_and_seedbox(gradio_version_is_above_4):
    """Create the CFG-scale slider and a seed textbox with a reroll button.

    Args:
        gradio_version_is_above_4: selects the component-update API — Gradio 4+
            returns a fresh ``gr.Textbox(...)`` while older versions use
            ``gr.Textbox.update(...)``.

    Returns:
        (cfg_scale_slider, seed_textbox, seed_button)
    """
    cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)

    with gr.Row():
        seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
        seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")

        def _reroll_seed():
            # Bug fix: the original passed the float 1e8 as the upper bound;
            # random.randint requires integers and raises on Python >= 3.12
            # (non-integer randrange arguments were removed).
            new_seed = random.randint(1, 10**8)
            if gradio_version_is_above_4:
                return gr.Textbox(value=new_seed)
            return gr.Textbox.update(value=new_seed)

        seed_button.click(
            fn=_reroll_seed,
            inputs=[],
            outputs=[seed_textbox]
        )
    return cfg_scale_slider, seed_textbox, seed_button
|
| 280 |
+
|
| 281 |
+
def create_ui_outputs():
    """Output widgets: optional still image, generated video, and a status box."""
    with gr.Column():
        result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
        result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
        infer_progress = gr.Textbox(
            label="Generation Info (生成信息)",
            interactive=False,
            value="No task currently",
        )
    return result_image, result_video, infer_progress
|
videox_fun/ui/wan_fun_ui.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
|
| 2 |
+
"""
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from omegaconf import OmegaConf
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from safetensors import safe_open
|
| 13 |
+
|
| 14 |
+
from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
|
| 15 |
+
from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
|
| 16 |
+
WanT5EncoderModel, WanTransformer3DModel)
|
| 17 |
+
from ..models.cache_utils import get_teacache_coefficients
|
| 18 |
+
from ..pipeline import WanFunInpaintPipeline, WanFunPipeline, WanFunControlPipeline
|
| 19 |
+
from ..utils.fp8_optimization import (convert_model_weight_to_float8,
|
| 20 |
+
convert_weight_dtype_wrapper,
|
| 21 |
+
replace_parameters_by_name)
|
| 22 |
+
from ..utils.lora_utils import merge_lora, unmerge_lora
|
| 23 |
+
from ..utils.utils import (filter_kwargs, get_image_to_video_latent,
|
| 24 |
+
get_video_to_video_latent, save_videos_grid)
|
| 25 |
+
from .controller import (Fun_Controller, Fun_Controller_Client,
|
| 26 |
+
all_cheduler_dict, css, ddpm_scheduler_dict,
|
| 27 |
+
flow_scheduler_dict, gradio_version,
|
| 28 |
+
gradio_version_is_above_4)
|
| 29 |
+
from .ui import (create_cfg_and_seedbox,
|
| 30 |
+
create_fake_finetune_models_checkpoints,
|
| 31 |
+
create_fake_height_width, create_fake_model_checkpoints,
|
| 32 |
+
create_fake_model_type, create_finetune_models_checkpoints,
|
| 33 |
+
create_generation_method,
|
| 34 |
+
create_generation_methods_and_video_length,
|
| 35 |
+
create_height_width, create_model_checkpoints,
|
| 36 |
+
create_model_type, create_prompts, create_samplers,
|
| 37 |
+
create_ui_outputs)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class Wan_Fun_Controller(Fun_Controller):
|
| 41 |
+
    def update_diffusion_transformer(self, diffusion_transformer_dropdown):
        """Load all Wan-Fun submodels from the selected checkpoint directory
        and assemble the matching pipeline.

        Loads (in order) VAE, transformer, tokenizer, text encoder, and — only
        when the transformer has extra input channels — the CLIP image
        encoder. The pipeline class is chosen by ``self.model_type`` and by
        whether the model is an inpaint variant, then placed on the device
        according to ``self.GPU_memory_mode``.

        Returns a no-op ``gr.update()`` so it can be wired as a dropdown
        change handler.
        """
        print("Update diffusion transformer")
        self.diffusion_transformer_dropdown = diffusion_transformer_dropdown
        if diffusion_transformer_dropdown == "none":
            # Nothing selected: leave all models untouched.
            return gr.update()
        self.vae = AutoencoderKLWan.from_pretrained(
            os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')),
            additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']),
        ).to(self.weight_dtype)

        # Get Transformer
        self.transformer = WanTransformer3DModel.from_pretrained(
            os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
            transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']),
            low_cpu_mem_usage=True,
            torch_dtype=self.weight_dtype,
        )

        # Get Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
        )

        # Get Text encoder
        self.text_encoder = WanT5EncoderModel.from_pretrained(
            os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
            additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']),
            low_cpu_mem_usage=True,
            torch_dtype=self.weight_dtype,
        )
        self.text_encoder = self.text_encoder.eval()

        # Channel mismatch between transformer input and VAE latents signals
        # an image-conditioned (inpaint/control) model that also needs CLIP.
        if self.transformer.config.in_channels != self.vae.config.latent_channels:
            # Get Clip Image Encoder
            self.clip_image_encoder = CLIPModel.from_pretrained(
                os.path.join(diffusion_transformer_dropdown, self.config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
            ).to(self.weight_dtype)
            self.clip_image_encoder = self.clip_image_encoder.eval()
        else:
            self.clip_image_encoder = None

        # Default scheduler: first entry of the configured scheduler dict,
        # constructed with only the kwargs that class actually accepts.
        Choosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]]
        self.scheduler = Choosen_Scheduler(
            **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs']))
        )

        # Get pipeline
        if self.model_type == "Inpaint":
            if self.transformer.config.in_channels != self.vae.config.latent_channels:
                self.pipeline = WanFunInpaintPipeline(
                    vae=self.vae,
                    tokenizer=self.tokenizer,
                    text_encoder=self.text_encoder,
                    transformer=self.transformer,
                    scheduler=self.scheduler,
                    clip_image_encoder=self.clip_image_encoder,
                )
            else:
                self.pipeline = WanFunPipeline(
                    vae=self.vae,
                    tokenizer=self.tokenizer,
                    text_encoder=self.text_encoder,
                    transformer=self.transformer,
                    scheduler=self.scheduler,
                )
        else:
            self.pipeline = WanFunControlPipeline(
                vae=self.vae,
                tokenizer=self.tokenizer,
                text_encoder=self.text_encoder,
                transformer=self.transformer,
                scheduler=self.scheduler,
                clip_image_encoder=self.clip_image_encoder,
            )

        # Sequence-parallel inference across GPUs when configured.
        if self.ulysses_degree > 1 or self.ring_degree > 1:
            self.transformer.enable_multi_gpus_inference()

        # Device placement / memory-saving strategy.
        if self.GPU_memory_mode == "sequential_cpu_offload":
            # "modulation" params (and the rotary freqs) must stay on-device
            # even under sequential offload.
            replace_parameters_by_name(self.transformer, ["modulation",], device=self.device)
            self.transformer.freqs = self.transformer.freqs.to(device=self.device)
            self.pipeline.enable_sequential_cpu_offload(device=self.device)
        elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
            convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",])
            convert_weight_dtype_wrapper(self.transformer, self.weight_dtype)
            self.pipeline.enable_model_cpu_offload(device=self.device)
        elif self.GPU_memory_mode == "model_cpu_offload":
            self.pipeline.enable_model_cpu_offload(device=self.device)
        else:
            self.pipeline.to(self.device)
        print("Update diffusion transformer done")
        return gr.update()
|
| 133 |
+
|
| 134 |
+
def generate(
    self,
    diffusion_transformer_dropdown,
    base_model_dropdown,
    lora_model_dropdown,
    lora_alpha_slider,
    prompt_textbox,
    negative_prompt_textbox,
    sampler_dropdown,
    sample_step_slider,
    resize_method,
    width_slider,
    height_slider,
    base_resolution,
    generation_method,
    length_slider,
    overlap_video_length,
    partial_video_length,
    cfg_scale_slider,
    start_image,
    end_image,
    validation_video,
    validation_video_mask,
    control_video,
    denoise_strength,
    seed_textbox,
    is_api = False,
):
    """Run one generation request from the UI (or the HTTP API).

    Validates inputs, hot-swaps base/LoRA weights if the dropdowns changed,
    configures the sampler/TeaCache/RiFLEx, runs the pipeline for the selected
    mode (inpaint / text-to-video / control), saves the result and returns the
    gradio (or API) outputs. Returns either (path, status) when ``is_api`` is
    True, or (image_update, video_update, status) for the UI.
    """
    self.clear_cache()

    self.input_check(
        resize_method, generation_method, start_image, end_image, validation_video, control_video, is_api
    )
    # "Image Generation" produces a single frame instead of a clip.
    is_image = True if generation_method == "Image Generation" else False

    # Lazily reload weights only when the user picked different checkpoints.
    if self.base_model_path != base_model_dropdown:
        self.update_base_model(base_model_dropdown)

    if self.lora_model_path != lora_model_dropdown:
        self.update_lora_model(lora_model_dropdown)

    self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)

    if resize_method == "Resize according to Reference":
        height_slider, width_slider = self.get_height_width_from_reference(
            base_resolution, start_image, validation_video, control_video,
        )
    if self.lora_model_path != "none":
        # lora part: fuse LoRA weights into the pipeline for this run.
        self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

    coefficients = get_teacache_coefficients(self.base_model_path) if self.enable_teacache else None
    if coefficients is not None:
        print(f"Enable TeaCache with threshold {self.teacache_threshold} and skip the first {self.num_skip_start_steps} steps.")
        self.pipeline.transformer.enable_teacache(
            coefficients, sample_step_slider, self.teacache_threshold, num_skip_start_steps=self.num_skip_start_steps, offload=self.teacache_offload
        )

    # BUGFIX: check for the empty string BEFORE calling int(); the original
    # order raised ValueError when the seed textbox was left blank.
    if seed_textbox != "" and int(seed_textbox) != -1:
        torch.manual_seed(int(seed_textbox))
    else:
        seed_textbox = np.random.randint(0, int(1e10))
    generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))

    if self.enable_riflex:
        # Number of latent frames after the VAE's temporal compression.
        latent_frames = (int(length_slider) - 1) // self.vae.config.temporal_compression_ratio + 1
        self.pipeline.transformer.enable_riflex(k = self.riflex_k, L_test = latent_frames if not is_image else 1)

    try:
        if self.model_type == "Inpaint":
            # Extra input channels mean the transformer takes video + mask conditioning.
            if self.transformer.config.in_channels != self.vae.config.latent_channels:
                if validation_video is not None:
                    input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=16)
                else:
                    input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))

                sample = self.pipeline(
                    prompt_textbox,
                    negative_prompt = negative_prompt_textbox,
                    num_inference_steps = sample_step_slider,
                    guidance_scale = cfg_scale_slider,
                    width = width_slider,
                    height = height_slider,
                    num_frames = length_slider if not is_image else 1,
                    generator = generator,

                    video = input_video,
                    mask_video = input_video_mask,
                    clip_image = clip_image
                ).videos
            else:
                # Plain text-to-video pipeline: no conditioning inputs.
                sample = self.pipeline(
                    prompt_textbox,
                    negative_prompt = negative_prompt_textbox,
                    num_inference_steps = sample_step_slider,
                    guidance_scale = cfg_scale_slider,
                    width = width_slider,
                    height = height_slider,
                    num_frames = length_slider if not is_image else 1,
                    generator = generator
                ).videos
        else:
            # Control mode: condition on a control video (e.g. pose/depth).
            input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=16)

            sample = self.pipeline(
                prompt_textbox,
                negative_prompt = negative_prompt_textbox,
                num_inference_steps = sample_step_slider,
                guidance_scale = cfg_scale_slider,
                width = width_slider,
                height = height_slider,
                num_frames = length_slider if not is_image else 1,
                generator = generator,

                control_video = input_video,
            ).videos
    except Exception as e:
        # Clean up GPU cache and un-fuse LoRA before reporting the failure,
        # so the controller stays usable for the next request.
        self.clear_cache()
        if self.lora_model_path != "none":
            self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
        if is_api:
            return "", f"Error. error information is {str(e)}"
        else:
            return gr.update(), gr.update(), f"Error. error information is {str(e)}"

    self.clear_cache()
    # lora part: restore the original weights after a successful run.
    if self.lora_model_path != "none":
        self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

    save_sample_path = self.save_outputs(
        is_image, length_slider, sample, fps=16
    )

    if is_image or length_slider == 1:
        if is_api:
            return save_sample_path, "Success"
        else:
            # gradio >= 4 replaces Component.update() with component constructors.
            if gradio_version_is_above_4:
                return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
            else:
                return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
    else:
        if is_api:
            return save_sample_path, "Success"
        else:
            if gradio_version_is_above_4:
                return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
            else:
                return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
+
# Aliases selecting which controller backs each UI flavour: the host UI drives
# the full local controller; the client UI reuses the generic remote-client
# controller (presumably talking to a host over HTTP — confirm against
# Fun_Controller_Client).
Wan_Fun_Controller_Host = Wan_Fun_Controller
Wan_Fun_Controller_Client = Fun_Controller_Client
| 286 |
+
def ui(GPU_memory_mode, scheduler_dict, config_path, ulysses_degree, ring_degree, enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload, enable_riflex, riflex_k, weight_dtype, savedir_sample=None):
    """Build the full interactive gradio UI with a locally-owned controller.

    Unlike ``ui_host``/``ui_client``, this variant starts with no model loaded
    (``model_name=None``) and exposes real model/LoRA pickers plus a model-type
    switch. Returns ``(demo, controller)``.
    """
    controller = Wan_Fun_Controller(
        GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
        config_path=config_path, ulysses_degree=ulysses_degree, ring_degree=ring_degree,
        enable_teacache=enable_teacache, teacache_threshold=teacache_threshold,
        num_skip_start_steps=num_skip_start_steps, teacache_offload=teacache_offload,
        enable_riflex=enable_riflex, riflex_k=riflex_k, weight_dtype=weight_dtype,
        savedir_sample=savedir_sample,
    )

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # Wan-Fun:

            A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 81), as well as image generated videos.

            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
            """
        )
        # Model selection panel: type switch + transformer / base / LoRA pickers.
        with gr.Column(variant="panel"):
            model_type = create_model_type(visible=True)
            diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \
                create_model_checkpoints(controller, visible=True)
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \
                create_finetune_models_checkpoints(controller, visible=True)

        # Generation panel: prompts, sampler, resolution, length, sources, seed.
        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)

                    resize_method, width_slider, height_slider, base_resolution = create_height_width(
                        default_height = 480, default_width = 832, maximum_height = 1344,
                        maximum_width = 1344,
                    )
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation", "Image Generation"],
                            default_video_length=81,
                            maximum_video_length=81,
                        )
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox
                    )
                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

            # Rebuild pipelines when the user flips Inpaint/Control.
            model_type.change(
                fn=controller.update_model_type,
                inputs=[model_type],
                outputs=[]
            )

            def upload_generation_method(generation_method):
                # Toggle length/overlap/partial controls per generation mode.
                if generation_method == "Video Generation":
                    return [gr.update(visible=True, maximum=81, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)]
                elif generation_method == "Image Generation":
                    return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)]
                else:
                    return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
            generation_method.change(
                upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
            )

            def upload_source_method(source_method):
                # Show only the input widgets relevant to the chosen source,
                # clearing the hidden ones so stale files are not reused.
                if source_method == "Text to Video (文本到视频)":
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Image to Video (图片到视频)":
                    return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Video to Video (视频到视频)":
                    return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
            source_method.change(
                upload_source_method, source_method, [
                    image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
                    validation_video, validation_video_mask, control_video
                ]
            )

            def upload_resize_method(resize_method):
                # "Generate by" exposes explicit W/H sliders; otherwise only
                # the base-resolution picker is shown.
                if resize_method == "Generate by":
                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
            resize_method.change(
                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
            )

            generate_button.click(
                fn=controller.generate,
                inputs=[
                    diffusion_transformer_dropdown,
                    base_model_dropdown,
                    lora_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    resize_method,
                    width_slider,
                    height_slider,
                    base_resolution,
                    generation_method,
                    length_slider,
                    overlap_video_length,
                    partial_video_length,
                    cfg_scale_slider,
                    start_image,
                    end_image,
                    validation_video,
                    validation_video_mask,
                    control_video,
                    denoise_strength,
                    seed_textbox,
                ],
                outputs=[result_image, result_video, infer_progress]
            )
    return demo, controller
| 413 |
+
def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, config_path, ulysses_degree, ring_degree, enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload, enable_riflex, riflex_k, weight_dtype, savedir_sample=None):
    """Build the "host" gradio UI: one fixed, pre-loaded model.

    Model-type and checkpoint widgets are non-interactive fakes here, since
    ``model_name``/``model_type`` are decided at startup. Returns
    ``(demo, controller)``.
    """
    controller = Wan_Fun_Controller_Host(
        GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type,
        config_path=config_path, ulysses_degree=ulysses_degree, ring_degree=ring_degree,
        enable_teacache=enable_teacache, teacache_threshold=teacache_threshold,
        num_skip_start_steps=num_skip_start_steps, teacache_offload=teacache_offload,
        enable_riflex=enable_riflex, riflex_k=riflex_k, weight_dtype=weight_dtype,
        savedir_sample=savedir_sample,
    )

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # Wan-Fun:

            A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 81), as well as image generated videos.

            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
            """
        )
        # Read-only model info panel (the model is fixed at launch time).
        with gr.Column(variant="panel"):
            model_type = create_fake_model_type(visible=True)
            diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)

        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)

                    resize_method, width_slider, height_slider, base_resolution = create_height_width(
                        default_height = 480, default_width = 832, maximum_height = 1344,
                        maximum_width = 1344,
                    )
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation", "Image Generation"],
                            default_video_length=81,
                            maximum_video_length=81,
                        )
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox
                    )
                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

            def upload_generation_method(generation_method):
                # Only the length slider changes here (no long-video controls).
                if generation_method == "Video Generation":
                    return gr.update(visible=True, minimum=1, maximum=81, value=81, interactive=True)
                elif generation_method == "Image Generation":
                    return gr.update(minimum=1, maximum=1, value=1, interactive=False)
            generation_method.change(
                upload_generation_method, generation_method, [length_slider]
            )

            def upload_source_method(source_method):
                # Show only the widgets for the chosen source, clearing hidden ones.
                if source_method == "Text to Video (文本到视频)":
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Image to Video (图片到视频)":
                    return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Video to Video (视频到视频)":
                    return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
            source_method.change(
                upload_source_method, source_method, [
                    image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
                    validation_video, validation_video_mask, control_video
                ]
            )

            def upload_resize_method(resize_method):
                if resize_method == "Generate by":
                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
            resize_method.change(
                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
            )

            generate_button.click(
                fn=controller.generate,
                inputs=[
                    diffusion_transformer_dropdown,
                    base_model_dropdown,
                    lora_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    resize_method,
                    width_slider,
                    height_slider,
                    base_resolution,
                    generation_method,
                    length_slider,
                    overlap_video_length,
                    partial_video_length,
                    cfg_scale_slider,
                    start_image,
                    end_image,
                    validation_video,
                    validation_video_mask,
                    control_video,
                    denoise_strength,
                    seed_textbox,
                ],
                outputs=[result_image, result_video, infer_progress]
            )
    return demo, controller
| 530 |
+
def ui_client(scheduler_dict, model_name, savedir_sample=None):
    """Build the lightweight "client" gradio UI that forwards requests to a
    remote host controller (no local model, fewer controls — no control-video
    source and no long-video options). Returns ``(demo, controller)``.
    """
    controller = Wan_Fun_Controller_Client(scheduler_dict, savedir_sample)

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # Wan-Fun:

            A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 81), as well as image generated videos.

            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
            """
        )
        # Read-only model info (the actual model lives on the host).
        with gr.Column(variant="panel"):
            diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)

        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50)

                    resize_method, width_slider, height_slider, base_resolution = create_fake_height_width(
                        default_height = 480, default_width = 832, maximum_height = 1344,
                        maximum_width = 1344,
                    )
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation", "Image Generation"],
                            default_video_length=81,
                            maximum_video_length=81,
                        )
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox
                    )

                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

            def upload_generation_method(generation_method):
                # NOTE(review): video mode defaults to 49 frames here (min 5),
                # unlike the 81-frame default in ui()/ui_host() — presumably to
                # keep remote requests cheap; confirm this is intentional.
                if generation_method == "Video Generation":
                    return gr.update(visible=True, minimum=5, maximum=81, value=49, interactive=True)
                elif generation_method == "Image Generation":
                    return gr.update(minimum=1, maximum=1, value=1, interactive=False)
            generation_method.change(
                upload_generation_method, generation_method, [length_slider]
            )

            def upload_source_method(source_method):
                # Only text/image/video sources exist in the client UI (no control video).
                if source_method == "Text to Video (文本到视频)":
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
                elif source_method == "Image to Video (图片到视频)":
                    return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
                else:
                    return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
            source_method.change(
                upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
            )

            def upload_resize_method(resize_method):
                if resize_method == "Generate by":
                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
                else:
                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
            resize_method.change(
                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
            )

            generate_button.click(
                fn=controller.generate,
                inputs=[
                    diffusion_transformer_dropdown,
                    base_model_dropdown,
                    lora_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    resize_method,
                    width_slider,
                    height_slider,
                    base_resolution,
                    generation_method,
                    length_slider,
                    cfg_scale_slider,
                    start_image,
                    end_image,
                    validation_video,
                    validation_video_mask,
                    denoise_strength,
                    seed_textbox,
                ],
                outputs=[result_image, result_video, infer_progress]
            )
    return demo, controller
|
videox_fun/utils/__init__.py
ADDED
|
File without changes
|
videox_fun/utils/discrete_sampler.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modified from https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py
|
| 2 |
+
"""
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
class DiscreteSampling:
    """Sample random discrete indices in ``[0, num_idx)``.

    With ``uniform_sampling=True`` under an initialized torch.distributed run,
    the ranks are partitioned into groups and each group draws only from its
    own contiguous slice ("sigma interval") of the index range, stratifying
    the samples across the world.
    """

    def __init__(self, num_idx, uniform_sampling=False):
        self.num_idx = num_idx
        self.uniform_sampling = uniform_sampling
        self.is_distributed = (
            torch.distributed.is_available() and torch.distributed.is_initialized()
        )

        if self.is_distributed and self.uniform_sampling:
            world_size = torch.distributed.get_world_size()
            self.rank = torch.distributed.get_rank()

            # Smallest divisor of world_size such that the resulting group
            # count (world_size // divisor) also divides num_idx evenly.
            divisor = 1
            while world_size % divisor != 0 or num_idx % (world_size // divisor) != 0:
                divisor += 1
            self.group_num = world_size // divisor

            assert self.group_num > 0
            assert world_size % self.group_num == 0
            # the number of rank in one group
            self.group_width = world_size // self.group_num
            self.sigma_interval = self.num_idx // self.group_num
            print('rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s' % (
                self.rank, world_size, self.group_num,
                self.group_width, self.sigma_interval))

    def __call__(self, n_samples, generator=None, device=None):
        """Return a 1-D LongTensor of ``n_samples`` random indices."""
        if self.is_distributed and self.uniform_sampling:
            # Restrict this rank to its group's slice of the index range.
            group_index = self.rank // self.group_width
            low = group_index * self.sigma_interval
            high = (group_index + 1) * self.sigma_interval
            idx = torch.randint(
                low, high, (n_samples,),
                generator=generator, device=device,
            )
            print('proc[%d] idx=%s' % (self.rank, idx))
        else:
            idx = torch.randint(
                0, self.num_idx, (n_samples,),
                generator=generator, device=device,
            )
        return idx
|
videox_fun/utils/fp8_optimization.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modified from https://github.com/kijai/ComfyUI-MochiWrapper
|
| 2 |
+
"""
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
|
| 7 |
+
weight_dtype = cls.weight.dtype
|
| 8 |
+
cls.to(origin_dtype)
|
| 9 |
+
|
| 10 |
+
# Convert all inputs to the original dtype
|
| 11 |
+
inputs = [input.to(origin_dtype) for input in inputs]
|
| 12 |
+
out = cls.original_forward(*inputs, **kwargs)
|
| 13 |
+
|
| 14 |
+
cls.to(weight_dtype)
|
| 15 |
+
return out
|
| 16 |
+
|
| 17 |
+
def replace_parameters_by_name(module, name_keywords, device):
|
| 18 |
+
from torch import nn
|
| 19 |
+
for name, param in list(module.named_parameters(recurse=False)):
|
| 20 |
+
if any(keyword in name for keyword in name_keywords):
|
| 21 |
+
if isinstance(param, nn.Parameter):
|
| 22 |
+
tensor = param.data
|
| 23 |
+
delattr(module, name)
|
| 24 |
+
setattr(module, name, tensor.to(device=device))
|
| 25 |
+
for child_name, child_module in module.named_children():
|
| 26 |
+
replace_parameters_by_name(child_module, name_keywords, device)
|
| 27 |
+
|
| 28 |
+
def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens']):
|
| 29 |
+
for name, module in model.named_modules():
|
| 30 |
+
flag = False
|
| 31 |
+
for _exclude_module_name in exclude_module_name:
|
| 32 |
+
if _exclude_module_name in name:
|
| 33 |
+
flag = True
|
| 34 |
+
if flag:
|
| 35 |
+
continue
|
| 36 |
+
for param_name, param in module.named_parameters():
|
| 37 |
+
flag = False
|
| 38 |
+
for _exclude_module_name in exclude_module_name:
|
| 39 |
+
if _exclude_module_name in param_name:
|
| 40 |
+
flag = True
|
| 41 |
+
if flag:
|
| 42 |
+
continue
|
| 43 |
+
param.data = param.data.to(torch.float8_e4m3fn)
|
| 44 |
+
|
| 45 |
+
def convert_weight_dtype_wrapper(module, origin_dtype):
|
| 46 |
+
for name, module in module.named_modules():
|
| 47 |
+
if name == "" or "embed_tokens" in name:
|
| 48 |
+
continue
|
| 49 |
+
original_forward = module.forward
|
| 50 |
+
if hasattr(module, "weight") and module.weight is not None:
|
| 51 |
+
setattr(module, "original_forward", original_forward)
|
| 52 |
+
setattr(
|
| 53 |
+
module,
|
| 54 |
+
"forward",
|
| 55 |
+
lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs)
|
| 56 |
+
)
|
videox_fun/utils/lora_utils.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LoRA network module
|
| 2 |
+
# reference:
|
| 3 |
+
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
| 4 |
+
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
| 5 |
+
# https://github.com/bmaltais/kohya_ss
|
| 6 |
+
|
| 7 |
+
import hashlib
|
| 8 |
+
import math
|
| 9 |
+
import os
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
from io import BytesIO
|
| 12 |
+
from typing import List, Optional, Type, Union
|
| 13 |
+
|
| 14 |
+
import safetensors.torch
|
| 15 |
+
import torch
|
| 16 |
+
import torch.utils.checkpoint
|
| 17 |
+
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
| 18 |
+
from safetensors.torch import load_file
|
| 19 |
+
from transformers import T5EncoderModel
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class LoRAModule(torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
    ):
        """if alpha == 0 or None, alpha is rank (no scaling).

        lora_name: flattened identifier used as the state-dict key prefix.
        org_module: the Linear/Conv2d whose forward will be hijacked.
        multiplier: global strength applied to the LoRA delta at inference.
        lora_dim: LoRA rank.
        dropout / rank_dropout / module_dropout: training-time dropout probs.
        """
        super().__init__()
        self.lora_name = lora_name

        # Detect conv targets by isinstance rather than exact class name so
        # Conv2d subclasses (e.g. diffusers' LoRACompatibleConv, which
        # LoRANetwork explicitly targets) take the conv path instead of
        # crashing on the missing `in_features` attribute.
        is_conv2d = isinstance(org_module, torch.nn.Conv2d)
        if is_conv2d:
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim
        if is_conv2d:
            # lora_down mirrors the original kernel geometry; lora_up is 1x1.
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        # Persist alpha so saved weights carry the scaling convention.
        self.register_buffer("alpha", torch.tensor(alpha))

        # same as microsoft's: down is Kaiming-init, up starts at zero so the
        # adapter is an exact no-op before training.
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        """Hijack the wrapped module's forward and drop the back-reference."""
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x, *args, **kwargs):
        """Original forward plus the scaled (and optionally dropped-out) LoRA delta."""
        weight_dtype = x.dtype
        org_forwarded = self.org_forward(x)

        # module dropout: occasionally bypass the whole adapter during training.
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        lx = self.lora_down(x.to(self.lora_down.weight.dtype))

        # normal dropout
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        # rank dropout
        if self.rank_dropout is not None and self.training:
            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
            if len(lx.size()) == 3:
                mask = mask.unsqueeze(1)  # for Text Encoder
            elif len(lx.size()) == 4:
                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
            lx = lx * mask

            # scaling for rank dropout: treat as if the rank is changed
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
        else:
            scale = self.scale

        lx = self.lora_up(lx)

        return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
    digest = hashlib.sha256()
    # Legacy scheme: hash only the 64 KiB window starting 1 MiB into the file.
    b.seek(0x100000)
    digest.update(b.read(0x10000))
    return digest.hexdigest()[:8]
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
    digest = hashlib.sha256()
    chunk_size = 1024 * 1024

    # The safetensors file starts with an 8-byte little-endian header length;
    # the hash covers everything after the JSON header.
    b.seek(0)
    header_len = int.from_bytes(b.read(8), "little")

    b.seek(header_len + 8)
    while True:
        chunk = b.read(chunk_size)
        if not chunk:
            break
        digest.update(chunk)

    return digest.hexdigest()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    Returns a (model_hash, legacy_hash) pair computed over the serialized
    safetensors payload.
    """

    # Because writing user metadata to the file can change the result of
    # sd_models.model_hash(), only retain the training metadata for purposes of
    # calculating the hash, as they are meant to be immutable
    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    # Renamed from `bytes`: the original shadowed the builtin type.
    serialized = safetensors.torch.save(tensors, metadata)
    b = BytesIO(serialized)

    model_hash = addnet_hash_safetensors(b)
    legacy_hash = addnet_hash_legacy(b)
    return model_hash, legacy_hash
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class LoRANetwork(torch.nn.Module):
    """Container that discovers Linear/Conv children of target transformer and
    text-encoder blocks and manages one LoRAModule adapter per child.
    """
    # Class names of parent modules whose Linear/Conv children receive LoRA.
    TRANSFORMER_TARGET_REPLACE_MODULE = ["CogVideoXTransformer3DModel", "WanTransformer3DModel"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder", "T5SelfAttention", "T5CrossAttention"]
    # Prefixes baked into the flattened LoRA weight names (kohya-ss convention).
    LORA_PREFIX_TRANSFORMER = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    def __init__(
        self,
        text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
        unet,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        module_class: Type[object] = LoRAModule,
        skip_name: str = None,
        varbose: Optional[bool] = False,
    ) -> None:
        """Scan the given models and instantiate (but not yet apply) LoRA modules.

        text_encoder: a single text encoder, a list of them, or None entries.
        unet: the transformer ("unet" by historical naming) to adapt.
        multiplier: global LoRA strength forwarded to every module.
        lora_dim / alpha: rank and scaling numerator for each module.
        dropout: neuron dropout probability forwarded to each module.
        module_class: adapter implementation to instantiate.
        skip_name: substring; child modules whose name contains it are skipped.
        varbose: unused flag (typo of "verbose" kept for interface stability).
        """
        super().__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.dropout = dropout

        print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
        print(f"neuron dropout: p={self.dropout}")

        # create module instances
        def create_modules(
            is_unet: bool,
            root_module: torch.nn.Module,
            target_replace_modules: List[torch.nn.Module],
        ) -> List[LoRAModule]:
            # Weight-name prefix depends on which model family we are scanning.
            prefix = (
                self.LORA_PREFIX_TRANSFORMER
                if is_unet
                else self.LORA_PREFIX_TEXT_ENCODER
            )
            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if skip_name is not None and skip_name in child_name:
                            continue

                        if is_linear or is_conv2d:
                            # Flatten the dotted module path into a single
                            # underscore-separated LoRA weight name.
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")

                            dim = None
                            alpha = None

                            if is_linear or is_conv2d_1x1:
                                dim = self.lora_dim
                                alpha = self.alpha

                            # dim stays None for non-1x1 convs, so those are
                            # always skipped (recorded only if linear/1x1).
                            if dim is None or dim == 0:
                                if is_linear or is_conv2d_1x1:
                                    skipped.append(lora_name)
                                continue

                            lora = module_class(
                                lora_name,
                                child_module,
                                self.multiplier,
                                dim,
                                alpha,
                                dropout=dropout,
                            )
                            loras.append(lora)
            return loras, skipped

        # Normalize to a list so single-encoder callers work unchanged.
        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]

        self.text_encoder_loras = []
        skipped_te = []
        for i, text_encoder in enumerate(text_encoders):
            if text_encoder is not None:
                text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
                self.text_encoder_loras.extend(text_encoder_loras)
                skipped_te += skipped
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        # assertion: flattened names must be unique or state dicts would collide.
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        """Hijack the wrapped modules' forwards and register adapters as submodules.

        Note: the text_encoder/unet arguments are unused; modules were bound
        at construction time.
        """
        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

    def set_multiplier(self, multiplier):
        """Update the global LoRA strength on this network and every module."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        """Load LoRA weights from a .safetensors or torch checkpoint (non-strict)."""
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")
        # strict=False: tolerate checkpoints covering only a subset of modules.
        info = self.load_state_dict(weights_sd, False)
        return info

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        """Build optimizer param groups with per-component learning rates.

        Note: default_lr is accepted for interface compatibility but unused;
        groups without an explicit lr fall back to the optimizer's default.
        """
        self.requires_grad_(True)
        all_params = []

        def enumerate_params(loras):
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            return params

        if self.text_encoder_loras:
            param_data = {"params": enumerate_params(self.text_encoder_loras)}
            if text_encoder_lr is not None:
                param_data["lr"] = text_encoder_lr
            all_params.append(param_data)

        if self.unet_loras:
            param_data = {"params": enumerate_params(self.unet_loras)}
            if unet_lr is not None:
                param_data["lr"] = unet_lr
            all_params.append(param_data)

        return all_params

    def enable_gradient_checkpointing(self):
        # LoRA adapters carry no activation state worth checkpointing; no-op
        # kept so trainers can call this uniformly on all networks.
        pass

    def get_trainable_params(self):
        """Return all adapter parameters (the base models hold none of these)."""
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """Save adapter weights to .safetensors (with addnet hashes) or torch format.

        dtype: if not None, cast every tensor to this dtype on CPU before saving.
        metadata: extra metadata for safetensors; empty dict is treated as None.
        """
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)
|
| 340 |
+
|
| 341 |
+
def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
    transformer,
    neuron_dropout: Optional[float] = None,
    skip_name: str = None,
    **kwargs,
):
    """Factory for a LoRANetwork over the given transformer and text encoder(s).

    None dims/alphas fall back to the historical defaults (rank 4, alpha 1.0).
    Extra keyword arguments are accepted and ignored for interface stability.
    """
    dim = 4 if network_dim is None else network_dim  # default
    scaling_alpha = 1.0 if network_alpha is None else network_alpha

    return LoRANetwork(
        text_encoder,
        transformer,
        multiplier=multiplier,
        lora_dim=dim,
        alpha=scaling_alpha,
        dropout=neuron_dropout,
        skip_name=skip_name,
        varbose=True,
    )
|
| 367 |
+
|
| 368 |
+
def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
    """Merge a kohya-ss style LoRA checkpoint into the pipeline's weights in place.

    pipeline: diffusers-style pipeline with `.transformer` and `.text_encoder`.
    lora_path: .safetensors file, read unless `state_dict` is given directly.
    multiplier: merge strength applied to each LoRA delta.
    device / dtype: scratch device and dtype used for the merge matmuls.
    transformer_only: skip all "lora_te" (text-encoder) entries when True.
    Returns the same pipeline object.
    """
    LORA_PREFIX_TRANSFORMER = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    if state_dict is None:
        state_dict = load_file(lora_path, device=device)
    else:
        state_dict = state_dict  # NOTE(review): no-op assignment; caller-provided dict used as-is
    # Group flat "layer.elem" keys into {layer: {elem: tensor}}.
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    sequential_cpu_offload_flag = False
    # A meta-device transformer means sequential CPU offload hooks are active;
    # remove them before touching weights and re-enable afterwards.
    if pipeline.transformer.device == torch.device(type="meta"):
        pipeline.remove_all_hooks()
        sequential_cpu_offload_flag = True
        offload_device = pipeline._offload_device

    for layer, elems in updates.items():

        if "lora_te" in layer:
            if transformer_only:
                continue
            else:
                layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
                curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
            curr_layer = pipeline.transformer

        # Module paths were flattened with "_", which is ambiguous because
        # attribute names can themselves contain underscores. Greedily rejoin
        # fragments until an existing submodule attribute matches.
        try:
            curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
        except Exception:
            temp_name = layer_infos.pop(0)
            while len(layer_infos) > -1:  # always true; the loop exits via break
                try:
                    curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos))
                    break
                except Exception:
                    try:
                        curr_layer = curr_layer.__getattr__(temp_name)
                        if len(layer_infos) > 0:
                            temp_name = layer_infos.pop(0)
                        elif len(layer_infos) == 0:
                            break
                    except Exception:
                        if len(layer_infos) == 0:
                            print('Error loading layer')
                        if len(temp_name) > 0:
                            temp_name += "_" + layer_infos.pop(0)
                        else:
                            temp_name = layer_infos.pop(0)

        # Remember where the layer lives so it can be restored after merging.
        origin_dtype = curr_layer.weight.data.dtype
        origin_device = curr_layer.weight.data.device

        curr_layer = curr_layer.to(device, dtype)
        weight_up = elems['lora_up.weight'].to(device, dtype)
        weight_down = elems['lora_down.weight'].to(device, dtype)

        # kohya-ss convention: the stored alpha is normalized by the rank.
        if 'alpha' in elems.keys():
            alpha = elems['alpha'].item() / weight_up.shape[1]
        else:
            alpha = 1.0

        # 4D weights are 1x1 convs: collapse to 2D for the matmul, then restore.
        if len(weight_up.shape) == 4:
            curr_layer.weight.data += multiplier * alpha * torch.mm(
                weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
            ).unsqueeze(2).unsqueeze(3)
        else:
            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
        curr_layer = curr_layer.to(origin_device, origin_dtype)

    if sequential_cpu_offload_flag:
        pipeline.enable_sequential_cpu_offload(device=offload_device)
    return pipeline
|
| 444 |
+
|
| 445 |
+
# TODO: Refactor with merge_lora.
|
| 446 |
+
# TODO: Refactor with merge_lora.
def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
    """Unmerge state_dict in LoRANetwork from the pipeline in diffusers.

    Subtracts the LoRA deltas previously added by merge_lora (same checkpoint,
    same multiplier) from the pipeline's transformer / text-encoder weights.
    Returns the same pipeline object.
    """
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    state_dict = load_file(lora_path, device=device)

    # Group flat "layer.elem" keys into {layer: {elem: tensor}}.
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    sequential_cpu_offload_flag = False
    # A meta-device transformer means sequential CPU offload hooks are active;
    # remove them before touching weights and re-enable afterwards.
    if pipeline.transformer.device == torch.device(type="meta"):
        pipeline.remove_all_hooks()
        sequential_cpu_offload_flag = True
        offload_device = pipeline._offload_device

    for layer, elems in updates.items():

        if "lora_te" in layer:
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.transformer

        # Module paths were flattened with "_", which is ambiguous because
        # attribute names can themselves contain underscores. Greedily rejoin
        # fragments until an existing submodule attribute matches.
        try:
            curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
        except Exception:
            temp_name = layer_infos.pop(0)
            while len(layer_infos) > -1:  # always true; the loop exits via break
                try:
                    curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos))
                    break
                except Exception:
                    try:
                        curr_layer = curr_layer.__getattr__(temp_name)
                        if len(layer_infos) > 0:
                            temp_name = layer_infos.pop(0)
                        elif len(layer_infos) == 0:
                            break
                    except Exception:
                        if len(layer_infos) == 0:
                            print('Error loading layer')
                        if len(temp_name) > 0:
                            temp_name += "_" + layer_infos.pop(0)
                        else:
                            temp_name = layer_infos.pop(0)

        # Remember where the layer lives so it can be restored after unmerging.
        origin_dtype = curr_layer.weight.data.dtype
        origin_device = curr_layer.weight.data.device

        curr_layer = curr_layer.to(device, dtype)
        weight_up = elems['lora_up.weight'].to(device, dtype)
        weight_down = elems['lora_down.weight'].to(device, dtype)

        # kohya-ss convention: the stored alpha is normalized by the rank.
        if 'alpha' in elems.keys():
            alpha = elems['alpha'].item() / weight_up.shape[1]
        else:
            alpha = 1.0

        # 4D weights are 1x1 convs: collapse to 2D for the matmul, then restore.
        if len(weight_up.shape) == 4:
            curr_layer.weight.data -= multiplier * alpha * torch.mm(
                weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
            ).unsqueeze(2).unsqueeze(3)
        else:
            curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
        curr_layer = curr_layer.to(origin_device, origin_dtype)

    if sequential_cpu_offload_flag:
        # Fix: re-enable offload on the device it was originally configured for
        # (pipeline._offload_device, as merge_lora does). The previous code
        # passed the scratch `device` argument (default "cpu"), which could
        # silently retarget offload away from the configured accelerator.
        pipeline.enable_sequential_cpu_offload(device=offload_device)
    return pipeline
|