Dependencies

diffusers==0.30.1
peft==0.17.1
transformers==4.44
imageio
imageio-ffmpeg==0.5.1
pandas
wandb
accelerate
matplotlib
sentencepiece
yt-dlp
datasets
torch
torchvision
av
ffmpeg

Run the Code

import torch
import torch.nn as nn
import numpy as np
import imageio
from PIL import Image
import torchvision.transforms as T
from diffusers import CogVideoXPipeline
from huggingface_hub import PyTorchModelHubMixin
from peft import LoraConfig, get_peft_model, TaskType
from diffusers import CogVideoXTransformer3DModel


class TruncatedCogVideoX(nn.Module):
    """CogVideoX-2b transformer cut down to its first ``n_blocks`` blocks, plus a pooling head.

    The patch/time embeddings, positional embedding buffer, the first ``n_blocks``
    transformer blocks and ``norm_final`` are taken from the pretrained
    ``zai-org/CogVideoX-2b`` transformer; the remaining blocks and the
    ``norm_out``/``proj_out`` output layers are not kept.

    Args:
        n_blocks: number of leading transformer blocks to keep.
        out_dim: output size of the linear head (unused when ``reward_backbone``).
        reward_backbone: if True, no linear head is created and ``forward``
            returns the pooled features of the last latent frame instead.
    """

    def __init__(self, n_blocks=6, out_dim=1, reward_backbone=False):
        super().__init__()
        self.reward_backbone = reward_backbone

        pretrained_model = CogVideoXTransformer3DModel.from_pretrained(
            "zai-org/CogVideoX-2b",
            subfolder='transformer',
            torch_dtype=torch.float32
        )
        # Reuse the pretrained embedding layers of the CogVideoX backbone.
        self.patch_embed = pretrained_model.patch_embed
        self.embedding_dropout = pretrained_model.embedding_dropout
        self.time_proj = pretrained_model.time_proj
        self.time_embedding = pretrained_model.time_embedding
        self.config = pretrained_model.config
        # persistent=False: excluded from state_dict; taken from the pretrained
        # model each time this module is constructed.
        self.register_buffer("pos_embedding", pretrained_model.pos_embedding, persistent=False)

        # Token width inside the transformer (1920 for CogVideoX-2b: 30 heads * 64).
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.gradient_checkpointing = False

        # Keep only the first n blocks.
        self.transformer_blocks = nn.ModuleList(
            [pretrained_model.transformer_blocks[i] for i in range(n_blocks)]
        )

        # Norm applied after the kept transformer blocks.
        self.norm_final = pretrained_model.norm_final

        # Global average pooling over the spatial tokens of each latent frame.
        self.pool = nn.AdaptiveAvgPool1d(1)

        self.fc = None if reward_backbone else nn.Linear(self.inner_dim, out_dim)

        # Drop the last reference so the unused blocks / output layers can be freed.
        del pretrained_model

    def forward(self, hidden_states, timesteps, encoder_hidden_states, image_rotary_emb=None):
        """Encode a latent video and return pooled features or per-frame head outputs.

        Args:
            hidden_states: latent video laid out as (B, T, C, H, W) — batch,
                latent frames, channels, height, width.
            timesteps: diffusion timesteps fed to ``time_proj``; presumably
                shape (B,) — TODO confirm.
            encoder_hidden_states: text-encoder embeddings; dim 1 is the text
                sequence length.
            image_rotary_emb: forwarded unchanged to every transformer block;
                None when rotary positional embeddings are disabled (2B config).

        Returns:
            If ``reward_backbone``: (B, inner_dim) pooled features of the last
            latent frame. Otherwise: (B, T, out_dim) outputs of the linear head.
        """
        batch_size, num_frames, channels, height, width = hidden_states.shape
        inner_dim = self.inner_dim

        # time_proj has no weights; cast its output to the activation dtype so
        # time_embedding works whether the model runs in fp16 or fp32.
        t_emb = self.time_proj(timesteps).to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, None)

        # Joint patch embedding: text tokens first, then video tokens (split below).
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)

        text_seq_length = encoder_hidden_states.shape[1]

        if not self.config.use_rotary_positional_embeddings:
            # Absolute positional embedding covering text + video tokens.
            seq_length = height * width * num_frames // (self.config.patch_size**2)
            pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
            hidden_states = hidden_states + pos_embeds
            hidden_states = self.embedding_dropout(hidden_states)

        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        for block in self.transformer_blocks:
            if self.gradient_checkpointing:
                # Non-reentrant checkpointing: recompute block activations in backward.
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    use_reentrant=False,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B: norm is applied over the joint text + video sequence.
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # (B, T*h*w, D) -> (B, T, D, h*w) -> (B, T*D, h*w) -> mean over h*w -> (B, T, D)
        hidden_states = hidden_states.view(batch_size, num_frames, -1, inner_dim).permute(0, 1, 3, 2)
        hidden_states = self.pool(hidden_states.reshape(batch_size, num_frames * inner_dim, -1))
        features = hidden_states.view(batch_size, num_frames, inner_dim)

        if self.reward_backbone:
            # Only the last latent frame's pooled features are returned.
            return features[:, -1, :]

        return self.fc(features)


class LoRACogVideoX(TruncatedCogVideoX):
    """TruncatedCogVideoX with LoRA adapters on the attention projections of selected blocks.

    Args:
        n_blocks: number of leading transformer blocks to keep.
        lora_layers: indices of kept blocks whose attention projections get LoRA
            adapters; indices outside ``range(n_blocks)`` are ignored.
        out_dim: forwarded to ``TruncatedCogVideoX``.
        r: LoRA rank.
        lora_alpha: LoRA scaling numerator.
        reward_backbone: forwarded to ``TruncatedCogVideoX``.
    """

    def __init__(self, n_blocks, lora_layers, out_dim=1, r=8, lora_alpha=16, reward_backbone=False):
        super().__init__(n_blocks=n_blocks, out_dim=out_dim, reward_backbone=reward_backbone)
        num_blocks = len(self.transformer_blocks)
        projections = ("to_q", "to_k", "to_v", "to_out.0")

        # NOTE(review): diffusers' CogVideoXBlock appears to define only `attn1`
        # (joint text/video attention); the `attn2` names are kept for parity with
        # the original target list but likely match no module — confirm.
        target_modules = []
        for i in lora_layers:
            if not 0 <= i < num_blocks:
                continue
            for attn in ("attn1", "attn2"):
                target_modules.extend(
                    f"transformer_blocks.{i}.{attn}.{proj}" for proj in projections
                )

        lora_config = LoraConfig(
            r=r, lora_alpha=lora_alpha, target_modules=target_modules,
            lora_dropout=0.05, bias="none", task_type=TaskType.FEATURE_EXTRACTION
        )
        # PEFT injects the LoRA layers into `self` in place; the returned PeftModel
        # wrapper is intentionally not kept. (Rebinding the local name `self`, as
        # before, had no effect outside __init__.)
        get_peft_model(self, lora_config)


class ModelWrapper(nn.Module, PyTorchModelHubMixin):
    """LoRA-adapted CogVideoX backbone with an optional reward MLP and linear scorer.

    Mixes in ``PyTorchModelHubMixin`` so the model can be loaded with
    ``from_pretrained``.

    Args:
        model_type: accepted but not used by this class.
        model_args: positional arguments forwarded to ``LoRACogVideoX``
            (n_blocks, lora_layers, out_dim, r, lora_alpha, reward_backbone).
        num_frames: accepted but not used by this class.
        mode: ``'linear'`` appends ``Linear(reward_final_size, 1)``.
        reward_hidden_size: hidden width of the reward MLP.
        reward_final_size: output width of the reward MLP.
        use_reward_layer: apply the reward MLP to the backbone output.
    """

    def __init__(self, model_type, model_args, num_frames, mode='linear',
                 reward_hidden_size=128, reward_final_size=3, use_reward_layer=False):
        super().__init__()
        self.use_reward_layer = use_reward_layer
        self.use_linear = mode == 'linear'

        self.backbone = LoRACogVideoX(*model_args)

        if self.use_reward_layer:
            # NOTE(review): expects 1920-dim backbone output, i.e. a backbone built
            # with reward_backbone=True (last-frame pooled features) — confirm.
            self.reward_layer = nn.Sequential(
                nn.Linear(1920, reward_hidden_size),
                nn.SiLU(),
                nn.Linear(reward_hidden_size, reward_final_size),
            )

        if self.use_linear:
            self.linear = nn.Linear(reward_final_size, 1)

    def forward(self, hidden_states, timesteps, encoder_hidden_states, image_rotary_emb=None):
        """Run the backbone, then the optional reward MLP, then the optional linear scorer."""
        features = self.backbone(hidden_states, timesteps, encoder_hidden_states, image_rotary_emb)
        if self.use_reward_layer:
            # squeeze(-1) only drops a trailing size-1 dim; no-op otherwise.
            features = self.reward_layer(features.squeeze(-1))
        return self.linear(features) if self.use_linear else features





def video_to_latent_tensor(
    video_path,
    prompt,
    pipe=None,
    device="cuda",
    h=480,
    w=720,
    num_frames=49,
    dtype=torch.float16
):
    """
    Encode a video with the pipeline's VAE and its prompt with the pipeline's text encoder.

    Args:
        video_path (str): path to the video file (read with imageio's "ffmpeg" format)
        prompt (str): text prompt to embed
        pipe (CogVideoXPipeline, optional): preloaded pipeline; if None, CogVideoX-2b
            is loaded on every call (slow — prefer passing one in)
        device (str): device for the loaded pipeline and all input tensors ("cuda" / "cpu")
        h, w (int): frame resize dimensions
        num_frames (int): maximum number of uniformly spaced frames to sample;
            videos with fewer frames use every frame
        dtype: torch dtype for the pipeline (when loaded here) and the video tensor

    Returns:
        tuple:
            latents (torch.Tensor): VAE latents multiplied by the VAE scaling
                factor, shape [1, C, T, H', W']
            prompt_embeds (torch.Tensor): first output of the text encoder for the
                prompt padded/truncated to 226 tokens

    Raises:
        ValueError: if ``num_frames`` < 1 or the video has no frames.
    """
    if num_frames < 1:
        raise ValueError("num_frames must be at least 1")

    # Load pipeline if not provided (better to reuse externally).
    # NOTE(review): the backbone loads "zai-org/CogVideoX-2b"; this ID is the
    # older org name for the same model — confirm they resolve identically.
    if pipe is None:
        pipe = CogVideoXPipeline.from_pretrained(
            "THUDM/CogVideoX-2b",
            torch_dtype=dtype
        ).to(device)

    vae = pipe.vae.eval()
    text_encoder = pipe.text_encoder.eval()
    tokenizer = pipe.tokenizer

    transform = T.Compose([
        T.Resize((h, w)),
        T.ToTensor(),
        T.Normalize([0.5], [0.5])
    ])

    # --- Read video; always release the ffmpeg reader, even if decoding fails ---
    reader = imageio.get_reader(video_path, format="ffmpeg")
    try:
        frames = list(reader)
    finally:
        reader.close()

    total_frames = len(frames)
    if total_frames == 0:
        raise ValueError("Video has no frames")

    # --- Uniform frame sampling; only the sampled frames are converted to PIL ---
    indices = np.linspace(0, total_frames - 1, min(total_frames, num_frames), dtype=int)
    sampled_frames = [transform(Image.fromarray(frames[i])) for i in indices]

    # --- Per-frame (C, H, W) stacked on dim 1 plus batch dim: [1, C, T, H, W] ---
    video_tensor = torch.stack(sampled_frames, dim=1).unsqueeze(0).to(device, dtype=dtype)

    # --- Encode ---
    with torch.no_grad():
        latents = vae.encode(video_tensor).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        input_ids = tokenizer(
            prompt,
            padding="max_length",
            max_length=226,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        ).input_ids
        # Honour `device` (was hard-coded to .cuda(), which broke device="cpu").
        prompt_embeds = text_encoder(input_ids.to(device))[0]

    return latents, prompt_embeds


if __name__ == '__main__':
    # Example: run the hub model on one clip and its prompt.
    sample_video = 'common_videos_gen3_instances/instance_00000/gen3_00102.mp4'
    sample_prompt = 'A ball rolling on a table'
    latents, text_embeds = video_to_latent_tensor(sample_video, sample_prompt)

    # The encoding pipeline was local to the helper; release cached GPU memory.
    torch.cuda.empty_cache()

    # Timestep 0 for every batch element, on the same device as the latents.
    batch = latents.shape[0]
    timestep = torch.zeros((batch,), dtype=torch.long, device=latents.device)
    # VAE latents are [B, C, T, H, W]; the backbone unpacks (B, T, C, H, W).
    latents = latents.permute(0, 2, 1, 3, 4)

    # .half(), .cuda() and .eval() each return the module itself.
    model = ModelWrapper.from_pretrained("sasuke-ss1/GT-SVJ").half().cuda().eval()

    with torch.no_grad():
        score = model(latents, timestep, text_embeds)
    print(score)
Downloads last month
67
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support