Dependencies
diffusers==0.30.1
peft==0.17.1
transformers==4.44
imageio
imageio-ffmpeg==0.5.1
pandas
wandb
accelerate
matplotlib
sentencepiece
yt-dlp
datasets
torch
torchvision
av
ffmpeg
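Assuming the list above is saved as a requirements.txt file, the Python packages can be installed in one step (the ffmpeg binary itself is usually easier to install through the system package manager):

pip install -r requirements.txt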
Run the Code
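The script below truncates the pretrained CogVideoX-2b transformer to its first few blocks, injects LoRA adapters into selected attention projections, and adds a pooling/score head on top. It then encodes one video into CogVideoX VAE latents and its prompt into T5 embeddings, and scores the pair with the sasuke-ss1/GT-SVJ checkpoint.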
import torch
import torch.nn as nn
import numpy as np
import imageio
from PIL import Image
import torchvision.transforms as T
from diffusers import CogVideoXPipeline
from huggingface_hub import PyTorchModelHubMixin
from peft import LoraConfig, get_peft_model, TaskType
from diffusers import CogVideoXTransformer3DModel
class TruncatedCogVideoX(nn.Module):
    def __init__(self, n_blocks=6, out_dim=1, reward_backbone=False):
        super().__init__()
        self.reward_backbone = reward_backbone
        pretrained_model = CogVideoXTransformer3DModel.from_pretrained(
            "zai-org/CogVideoX-2b",
            subfolder='transformer',
            torch_dtype=torch.float32
        )
        # Use pretrained CogVideoX backbone
        self.patch_embed = pretrained_model.patch_embed
        self.embedding_dropout = pretrained_model.embedding_dropout
        self.time_proj = pretrained_model.time_proj
        self.time_embedding = pretrained_model.time_embedding
        self.config = pretrained_model.config
        self.register_buffer("pos_embedding", pretrained_model.pos_embedding, persistent=False)
        self.gradient_checkpointing = False
        # Keep first n blocks
        self.transformer_blocks = nn.ModuleList(
            [pretrained_model.transformer_blocks[i] for i in range(n_blocks)]
        )
        # Norm after transformer
        self.norm_final = pretrained_model.norm_final
        # self.norm_out = pretrained_model.norm_out
        # self.proj_out = pretrained_model.proj_out
        # Classification head
        self.pool = nn.AdaptiveAvgPool1d(1)  # global average pooling over sequence
        if reward_backbone:
            self.fc = None
        else:
            self.fc = nn.Linear(1920, out_dim)  # backbone hidden dim = 1920
        del pretrained_model
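    # Design note (the block/width numbers are assumptions based on the public CogVideoX-2b
    # config): the full transformer has 30 blocks with hidden size 1920, so keeping only the
    # first n_blocks and dropping the denoising head (norm_out / proj_out) turns the generator
    # into a lightweight video feature extractor.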
    def forward(self, hidden_states, timesteps, encoder_hidden_states, image_rotary_emb=None):
        """
        hidden_states: (B, T, C, H, W) latent video tensor
        """
        batch_size, num_frames, channels, height, width = hidden_states.shape
        # Time embeddings (match the input dtype: float16 after model.half(), float32 otherwise)
        t_emb = self.time_proj(timesteps).to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, None)
        # Patch embeddings
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        # Position embedding
        text_seq_length = encoder_hidden_states.shape[1]
        if not self.config.use_rotary_positional_embeddings:
            seq_length = height * width * num_frames // (self.config.patch_size**2)
            pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
            hidden_states = hidden_states + pos_embeds
        hidden_states = self.embedding_dropout(hidden_states)
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]
        for i, block in enumerate(self.transformer_blocks):
            if self.gradient_checkpointing:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward
                ckpt_kwargs = {"use_reentrant": False}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )
        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]
        # Pool the per-frame token features: (B, seq, 1920) -> (B, T, 1920)
        hidden_states = hidden_states.view(batch_size, num_frames, -1, 1920).permute(0, 1, 3, 2)
        hidden_states = self.pool(hidden_states.reshape(batch_size, num_frames * 1920, -1))
        if self.reward_backbone:
            return hidden_states.view(batch_size, num_frames, 1920)[:, -1, :]
        output = self.fc(hidden_states.view(batch_size, num_frames, 1920))
        return output
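# Shape sketch (assuming 480x720, 49-frame clips and the CogVideoX-2b VAE, as in the example
# below): the latents reaching forward() are roughly (1, 13, 16, 60, 90); with patch_size 2 each
# latent frame becomes 30 * 45 = 1350 tokens of width 1920, which the pooling collapses back to
# one 1920-d vector per latent frame (reward_backbone=True keeps only the last frame's vector).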
class LoRACogVideoX(TruncatedCogVideoX):
    def __init__(self, n_blocks, lora_layers, out_dim=1, r=8, lora_alpha=16, reward_backbone=False):
        super().__init__(n_blocks=n_blocks, out_dim=out_dim, reward_backbone=reward_backbone)
        target_modules = []
        for i in lora_layers:
            if i < len(self.transformer_blocks):
                target_modules.extend([
                    f"transformer_blocks.{i}.attn1.to_q", f"transformer_blocks.{i}.attn1.to_k",
                    f"transformer_blocks.{i}.attn1.to_v", f"transformer_blocks.{i}.attn1.to_out.0",
                    f"transformer_blocks.{i}.attn2.to_q", f"transformer_blocks.{i}.attn2.to_k",
                    f"transformer_blocks.{i}.attn2.to_v", f"transformer_blocks.{i}.attn2.to_out.0",
                ])
        lora_config = LoraConfig(
            r=r, lora_alpha=lora_alpha, target_modules=target_modules,
            lora_dropout=0.05, bias="none", task_type=TaskType.FEATURE_EXTRACTION
        )
        self = get_peft_model(self, lora_config)
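# Note on the `self = get_peft_model(...)` call above: PEFT injects the LoRA A/B matrices into the
# targeted projections of `self` in place, so the adapters survive even though the returned
# PeftModel wrapper is discarded; it also sets requires_grad=False on the non-LoRA backbone weights.
# A quick way to list the injected parameters for some instance `model` (sketch):
#     lora_param_names = [n for n, _ in model.named_parameters() if "lora_" in n]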
class ModelWrapper(nn.Module, PyTorchModelHubMixin):
    def __init__(self, model_type, model_args, num_frames, mode='linear',
                 reward_hidden_size=128, reward_final_size=3, use_reward_layer=False):
        super().__init__()
        self.backbone = LoRACogVideoX(*model_args)
        self.use_reward_layer = use_reward_layer
        if use_reward_layer:
            self.reward_layer = nn.Sequential(
                nn.Linear(1920, reward_hidden_size),
                nn.SiLU(),
                nn.Linear(reward_hidden_size, reward_final_size),
            )
        self.use_linear = (mode == 'linear')
        if self.use_linear:
            self.linear = nn.Linear(reward_final_size, 1)

    def forward(self, hidden_states, timesteps, encoder_hidden_states, image_rotary_emb=None):
        out = self.backbone(hidden_states, timesteps, encoder_hidden_states, image_rotary_emb)
        if self.use_reward_layer:
            out = self.reward_layer(out.squeeze(-1))
        if self.use_linear:
            out = self.linear(out)
        return out
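# `model_args` is unpacked positionally into LoRACogVideoX(n_blocks, lora_layers, out_dim, r,
# lora_alpha, reward_backbone). A purely hypothetical configuration (not the values stored in the
# released checkpoint) would look like:
#     ModelWrapper(model_type="cogvideox", model_args=(6, [0, 1, 2], 1, 8, 16, True),
#                  num_frames=13, mode="linear", use_reward_layer=True)
# With from_pretrained, PyTorchModelHubMixin restores these arguments from the checkpoint's config.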
def video_to_latent_tensor(
    video_path,
    prompt,
    pipe=None,
    device="cuda",
    h=480,
    w=720,
    num_frames=49,
    dtype=torch.float16
):
    """
    Convert a single video and its prompt into CogVideoX VAE latents and T5 prompt embeddings.

    Args:
        video_path (str): path to the video file
        prompt (str): text prompt describing the video
        pipe (CogVideoXPipeline, optional): pass a preloaded pipeline to avoid reloading
        device (str): "cuda" or "cpu"
        h, w (int): resize dimensions
        num_frames (int): number of frames to sample
        dtype: torch dtype
    Returns:
        latents (torch.Tensor): shape [1, C, T', H', W'] (VAE-compressed frames and spatial dims)
        prompt_embeds (torch.Tensor): T5 text-encoder output for the prompt
    """
    # Load pipeline if not provided (better to reuse externally)
    if pipe is None:
        pipe = CogVideoXPipeline.from_pretrained(
            "THUDM/CogVideoX-2b",
            torch_dtype=dtype
        ).to(device)
    vae = pipe.vae
    text_encoder = pipe.text_encoder.eval()
    tokenizer = pipe.tokenizer
    vae.eval()
    transform = T.Compose([
        T.Resize((h, w)),
        T.ToTensor(),
        T.Normalize([0.5], [0.5])
    ])
    # --- Read video ---
    reader = imageio.get_reader(video_path, format="ffmpeg")
    frames = [Image.fromarray(frame) for frame in reader]
    reader.close()
    total_frames = len(frames)
    if total_frames == 0:
        raise ValueError("Video has no frames")
    # --- Frame sampling ---
    if total_frames < num_frames:
        indices = np.linspace(0, total_frames - 1, total_frames, dtype=int)
    else:
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    sampled_frames = [transform(frames[i]) for i in indices]
    # --- Tensor shape: [1, C, T, H, W] ---
    video_tensor = torch.stack(sampled_frames, dim=1).unsqueeze(0).to(device, dtype=dtype)
    # --- Encode ---
    with torch.no_grad():
        latents = vae.encode(video_tensor).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
        input_ids = tokenizer(
            prompt,
            padding="max_length",
            max_length=226,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        ).input_ids
        prompt_embeds = text_encoder(input_ids.to(device))[0]
    return latents, prompt_embeds
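# When scoring many videos, load the pipeline once and pass it in so the VAE and T5 encoder are
# not re-initialized on every call, e.g. (sketch):
#     pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
#     latents, embeds = video_to_latent_tensor(video_path, prompt, pipe=pipe)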
if __name__ == '__main__':
    # Encode the input video and prompt
    video_path = 'common_videos_gen3_instances/instance_00000/gen3_00102.mp4'
    prompt = 'A ball rolling on a table'
    video_input, prompt_embeds = video_to_latent_tensor(video_path, prompt)
    torch.cuda.empty_cache()
    timestep = torch.zeros((video_input.shape[0],), dtype=torch.long, device=video_input.device)
    # [1, C, T, H, W] -> [1, T, C, H, W], the layout expected by the model's forward pass
    video_input = video_input.permute(0, 2, 1, 3, 4)
    # Load the pretrained model from the Hub and run it
    model = ModelWrapper.from_pretrained("sasuke-ss1/GT-SVJ")
    model.half().cuda()
    model.eval()
    with torch.no_grad():
        output = model(video_input, timestep, prompt_embeds)
    print(output)