| import os |
| import sys |
| from pathlib import Path |
| from typing import Optional |
|
|
| ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) |
| sys.path.append(ROOT_DIR) |
|
|
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from PIL import Image |
| import imageio |
| import json |
| from diffsynth import WanVideoAstraPipeline, ModelManager |
| import argparse |
| from torchvision.transforms import v2 |
| from einops import rearrange |
| from scipy.spatial.transform import Rotation as R |
| import random |
| import copy |
| from datetime import datetime |
|
|
| VALID_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"} |
| class InlineVideoEncoder: |
| """Encode condition images/videos into VAE latents using the loaded pipeline.""" |
|
|
| def __init__(self, pipe: WanVideoAstraPipeline, device="cuda"): |
| self.device = getattr(pipe, "device", device) |
| self.tiler_kwargs = {"tiled": True, "tile_size": (34, 34), "tile_stride": (18, 16)} |
| self.frame_process = v2.Compose([ |
| v2.ToTensor(), |
| v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), |
| ]) |
|
|
| self.pipe = pipe |
|
|
| @staticmethod |
| def _crop_and_resize(image: Image.Image) -> Image.Image: |
| """Resize the image to the working resolution (832x480); no cropping is currently applied.""" |
| target_w, target_h = 832, 480 |
| return v2.functional.resize( |
| image, |
| (round(target_h), round(target_w)), |
| interpolation=v2.InterpolationMode.BILINEAR, |
| ) |
|
|
| def preprocess_frame(self, image: Image.Image) -> torch.Tensor: |
| image = image.convert("RGB") |
| image = self._crop_and_resize(image) |
| return self.frame_process(image) |
|
|
| def load_video_frames(self, video_path: Path) -> Optional[torch.Tensor]: |
| reader = imageio.get_reader(str(video_path)) |
| frames = [] |
| for frame_data in reader: |
| frame = Image.fromarray(frame_data) |
| frames.append(self.preprocess_frame(frame)) |
| reader.close() |
|
|
| if not frames: |
| return None |
|
|
| frames = torch.stack(frames, dim=0) |
| return rearrange(frames, "T C H W -> C T H W") |
|
|
| def encode_frames_to_latents(self, frames: torch.Tensor) -> torch.Tensor: |
| frames = frames.unsqueeze(0).to(self.device, dtype=torch.bfloat16) |
| with torch.no_grad(): |
| latents = self.pipe.encode_video(frames, **self.tiler_kwargs)[0] |
|
|
| if latents.dim() == 5 and latents.shape[0] == 1: |
| latents = latents.squeeze(0) |
| return latents.cpu() |
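|
| # Shape sketch (an illustration, not exact for every input): with 480x832 RGB frames the |
| # Wan VAE used here compresses spatially by 8x, so a [3, T, 480, 832] stack encodes to |
| # roughly [16, T', 60, 104] latents, where T' is the temporally compressed length |
| # (the rest of this script assumes a 4x temporal ratio via time_compression_ratio). |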
| |
| def image_to_frame_stack( |
| image_path: Path, |
| encoder: InlineVideoEncoder, |
| repeat_count: int = 10 |
| ) -> torch.Tensor: |
| """Repeat a single image into a tensor with specified number of frames, shape [C, T, H, W]""" |
| if image_path.suffix.lower() not in VALID_IMAGE_EXTENSIONS: |
| raise ValueError(f"Unsupported image format: {image_path.suffix}") |
|
|
| image = Image.open(str(image_path)) |
| frame = encoder.preprocess_frame(image) |
| frames = torch.stack([frame for _ in range(repeat_count)], dim=0) |
| return rearrange(frames, "T C H W -> C T H W") |
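|
| # Usage sketch (hypothetical path): turn one conditioning image into a 10-frame stack. |
| #   encoder = InlineVideoEncoder(pipe=pipe, device="cuda") |
| #   frames = image_to_frame_stack(Path("example.png"), encoder, repeat_count=10) |
| #   frames.shape  # torch.Size([3, 10, 480, 832]) |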
|
|
|
|
| def load_or_encode_condition( |
| condition_pth_path: Optional[str], |
| condition_video: Optional[str], |
| condition_image: Optional[str], |
| start_frame: int, |
| num_frames: int, |
| device: str, |
| pipe: WanVideoAstraPipeline, |
| ) -> tuple[torch.Tensor, dict]: |
| if condition_pth_path: |
| return load_encoded_video_from_pth(condition_pth_path, start_frame, num_frames) |
|
|
| encoder = InlineVideoEncoder(pipe=pipe, device=device) |
|
|
| if condition_video: |
| video_path = Path(condition_video).expanduser().resolve() |
| if not video_path.exists(): |
| raise FileNotFoundError(f"File not found: {video_path}") |
| frames = encoder.load_video_frames(video_path) |
| if frames is None: |
| raise ValueError(f"no valid frames in {video_path}") |
| elif condition_image: |
| image_path = Path(condition_image).expanduser().resolve() |
| if not image_path.exists(): |
| raise FileNotFoundError(f"File not found: {image_path}") |
| frames = image_to_frame_stack(image_path, encoder, repeat_count=10) |
| else: |
| raise ValueError("condition video or image is needed for video generation.") |
|
|
| latents = encoder.encode_frames_to_latents(frames) |
| encoded_data = {"latents": latents} |
|
|
| if start_frame + num_frames > latents.shape[1]: |
| raise ValueError( |
| f"Not enough frames after encoding: requested {start_frame + num_frames}, available {latents.shape[1]}" |
| ) |
|
|
| condition_latents = latents[:, start_frame:start_frame + num_frames, :, :] |
| return condition_latents, encoded_data |
|
|
|
|
|
|
| def compute_relative_pose_matrix(pose1, pose2): |
| """ |
| Compute relative pose between two consecutive frames, return 3x4 camera matrix [R_rel | t_rel] |
| |
| Args: |
| pose1: Camera pose of frame i, shape (7,) array [tx1, ty1, tz1, qx1, qy1, qz1, qw1] |
| pose2: Camera pose of frame i+1, shape (7,) array [tx2, ty2, tz2, qx2, qy2, qz2, qw2] |
| |
| Returns: |
| relative_matrix: 3x4 relative pose matrix, |
| first 3 columns are rotation matrix R_rel, |
| last column is translation vector t_rel |
| """ |
| |
| t1 = pose1[:3] |
| q1 = pose1[3:] |
| t2 = pose2[:3] |
| q2 = pose2[3:] |
| |
| |
| rot1 = R.from_quat(q1) |
| rot2 = R.from_quat(q2) |
| rot_rel = rot2 * rot1.inv() |
| R_rel = rot_rel.as_matrix() |
| |
| |
| R1_T = rot1.as_matrix().T |
| t_rel = R1_T @ (t2 - t1) |
| |
| |
| relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)]) |
| |
| return relative_matrix |
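|
| # Worked example (hypothetical values): two poses related by a pure 1-unit translation |
| # along x with identity rotation (quaternions in scipy's [qx, qy, qz, qw] order): |
| #   pose1 = np.array([0, 0, 0, 0, 0, 0, 1], dtype=np.float32) |
| #   pose2 = np.array([1, 0, 0, 0, 0, 0, 1], dtype=np.float32) |
| #   compute_relative_pose_matrix(pose1, pose2) |
| #   # -> [[1, 0, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0]]  (identity rotation, t_rel = [1, 0, 0]) |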
|
|
| def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10): |
| """Load pre-encoded video data from pth file""" |
| print(f"Loading encoded video from {pth_path}") |
| |
| encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu") |
| full_latents = encoded_data['latents'] |
| |
| print(f"Full latents shape: {full_latents.shape}") |
| print(f"Extracting frames {start_frame} to {start_frame + num_frames}") |
| |
| if start_frame + num_frames > full_latents.shape[1]: |
| raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}") |
| |
| condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :] |
| print(f"Extracted condition latents shape: {condition_latents.shape}") |
| |
| return condition_latents, encoded_data |
|
|
| def compute_relative_pose(pose_a, pose_b, use_torch=False): |
| """Compute relative pose matrix of camera B with respect to camera A""" |
| assert pose_a.shape == (4, 4), f"Camera A extrinsic matrix should be (4,4), got {pose_a.shape}" |
| assert pose_b.shape == (4, 4), f"Camera B extrinsic matrix should be (4,4), got {pose_b.shape}" |
| |
| if use_torch: |
| if not isinstance(pose_a, torch.Tensor): |
| pose_a = torch.from_numpy(pose_a).float() |
| if not isinstance(pose_b, torch.Tensor): |
| pose_b = torch.from_numpy(pose_b).float() |
| |
| pose_a_inv = torch.inverse(pose_a) |
| relative_pose = torch.matmul(pose_b, pose_a_inv) |
| else: |
| if not isinstance(pose_a, np.ndarray): |
| pose_a = np.array(pose_a, dtype=np.float32) |
| if not isinstance(pose_b, np.ndarray): |
| pose_b = np.array(pose_b, dtype=np.float32) |
| |
| pose_a_inv = np.linalg.inv(pose_a) |
| relative_pose = np.matmul(pose_b, pose_a_inv) |
| |
| return relative_pose |
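|
| # Sanity-check sketch: identical extrinsics give the identity as the relative pose. |
| #   T = np.eye(4, dtype=np.float32) |
| #   np.allclose(compute_relative_pose(T, T), np.eye(4))  # True |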
|
|
|
|
| def replace_dit_model_in_manager(): |
| """Replace DiT model class with MoE version""" |
| from diffsynth.models.wan_video_dit_moe import WanModelMoe |
| from diffsynth.configs.model_config import model_loader_configs |
| |
| for i, config in enumerate(model_loader_configs): |
| keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config |
| |
| if 'wan_video_dit' in model_names: |
| new_model_names = [] |
| new_model_classes = [] |
| |
| for name, cls in zip(model_names, model_classes): |
| if name == 'wan_video_dit': |
| new_model_names.append(name) |
| new_model_classes.append(WanModelMoe) |
| print(f"Replaced model class: {name} -> WanModelMoe") |
| else: |
| new_model_names.append(name) |
| new_model_classes.append(cls) |
| |
| model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource) |
|
|
|
|
| def add_framepack_components(dit_model): |
| """Add FramePack related components""" |
| if not hasattr(dit_model, 'clean_x_embedder'): |
| inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0] |
| |
| class CleanXEmbedder(nn.Module): |
| def __init__(self, inner_dim): |
| super().__init__() |
| self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) |
| self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) |
| self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) |
| |
| def forward(self, x, scale="1x"): |
| if scale == "1x": |
| x = x.to(self.proj.weight.dtype) |
| return self.proj(x) |
| elif scale == "2x": |
| x = x.to(self.proj_2x.weight.dtype) |
| return self.proj_2x(x) |
| elif scale == "4x": |
| x = x.to(self.proj_4x.weight.dtype) |
| return self.proj_4x(x) |
| else: |
| raise ValueError(f"Unsupported scale: {scale}") |
| |
| dit_model.clean_x_embedder = CleanXEmbedder(inner_dim) |
| model_dtype = next(dit_model.parameters()).dtype |
| dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype) |
| print("Added FramePack clean_x_embedder component") |
|
|
|
|
| def add_moe_components(dit_model, moe_config): |
| """Add MoE related components - corrected version""" |
| if not hasattr(dit_model, 'moe_config'): |
| dit_model.moe_config = moe_config |
| print("Added MoE config to model") |
| dit_model.top_k = moe_config.get("top_k", 1) |
|
|
| |
| dim = dit_model.blocks[0].self_attn.q.weight.shape[0] |
| unified_dim = moe_config.get("unified_dim", 25) |
| num_experts = moe_config.get("num_experts", 4) |
| from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE |
| dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim) |
| dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim) |
| dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim) |
| dit_model.global_router = nn.Linear(unified_dim, num_experts) |
|
|
|
|
| for i, block in enumerate(dit_model.blocks): |
| |
| block.moe = MultiModalMoE( |
| unified_dim=unified_dim, |
| output_dim=dim, |
| num_experts=moe_config.get("num_experts", 4), |
| top_k=moe_config.get("top_k", 2) |
| ) |
| |
| print(f"Block {i} added MoE component (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})") |
|
|
|
|
| def generate_sekai_camera_embeddings_sliding( |
| cam_data, |
| start_frame, |
| initial_condition_frames, |
| new_frames, |
| total_generated, |
| use_real_poses=True, |
| direction="left"): |
| """ |
| Generate camera embeddings for Sekai dataset - sliding window version |
| |
| Args: |
| cam_data: Dictionary containing Sekai camera extrinsic parameters, key 'extrinsic' corresponds to an N*4*4 numpy array |
| start_frame: Current generation start frame index |
| initial_condition_frames: Initial condition frame count |
| new_frames: Number of new frames to generate this time |
| total_generated: Total frames already generated |
| use_real_poses: Whether to use real Sekai camera poses |
| direction: Camera movement direction, default "left" |
| |
| Returns: |
| camera_embedding: Torch tensor of shape (M, 3*4 + 1), where M is the total number of generated frames |
| """ |
| time_compression_ratio = 4 |
| |
| |
| |
| framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames |
| |
| if use_real_poses and cam_data is not None and 'extrinsic' in cam_data: |
| print("🔧 Using real Sekai camera data") |
| cam_extrinsic = cam_data['extrinsic'] |
| |
| |
| max_needed_frames = max( |
| start_frame + initial_condition_frames + new_frames, |
| framepack_needed_frames, |
| 30 |
| ) |
| |
| print(f"🔧 Calculating Sekai camera sequence length:") |
| print(f" - Basic requirement: {start_frame + initial_condition_frames + new_frames}") |
| print(f" - FramePack requirement: {framepack_needed_frames}") |
| print(f" - Final generation: {max_needed_frames}") |
| |
| relative_poses = [] |
| for i in range(max_needed_frames): |
| |
| frame_idx = i * time_compression_ratio |
| next_frame_idx = frame_idx + time_compression_ratio |
| |
| if next_frame_idx < len(cam_extrinsic): |
| cam_prev = cam_extrinsic[frame_idx] |
| cam_next = cam_extrinsic[next_frame_idx] |
| relative_pose = compute_relative_pose(cam_prev, cam_next) |
| relative_poses.append(torch.as_tensor(relative_pose[:3, :])) |
| else: |
| |
| print(f"⚠️ Frame {frame_idx} exceeds camera data range, using zero motion") |
| relative_poses.append(torch.zeros(3, 4)) |
| |
| pose_embedding = torch.stack(relative_poses, dim=0) |
| pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') |
| |
| |
| mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
| |
| condition_end = min(start_frame + initial_condition_frames, max_needed_frames) |
| mask[start_frame:condition_end] = 1.0 |
| |
| camera_embedding = torch.cat([pose_embedding, mask], dim=1) |
| print(f"🔧 Sekai real camera embedding shape: {camera_embedding.shape}") |
| return camera_embedding.to(torch.bfloat16) |
| |
| else: |
| |
| max_needed_frames = max( |
| start_frame + initial_condition_frames + new_frames, |
| framepack_needed_frames, |
| 30) |
| |
| print(f"🔧 Generating Sekai synthetic camera frames: {max_needed_frames}") |
| |
| CONDITION_FRAMES = initial_condition_frames |
| STAGE_1 = new_frames//2 |
| STAGE_2 = new_frames - STAGE_1 |
| |
| if direction=="forward": |
| print("--------------- FORWARD MODE ---------------") |
| relative_poses = [] |
| for i in range(max_needed_frames): |
| if i < CONDITION_FRAMES: |
| |
| pose = np.eye(4, dtype=np.float32) |
| elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
| |
| forward_speed = 0.03 |
|
|
| pose = np.eye(4, dtype=np.float32) |
| pose[2, 3] = -forward_speed |
| else: |
| |
| pose = np.eye(4, dtype=np.float32) |
| |
| relative_pose = pose[:3, :] |
| relative_poses.append(torch.as_tensor(relative_pose)) |
| |
| elif direction=="left": |
| print("--------------- LEFT TURNING MODE ---------------") |
| relative_poses = [] |
| for i in range(max_needed_frames): |
| if i < CONDITION_FRAMES: |
| |
| pose = np.eye(4, dtype=np.float32) |
| elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
| |
| yaw_per_frame = 0.03 |
|
|
| |
| cos_yaw = np.cos(yaw_per_frame) |
| sin_yaw = np.sin(yaw_per_frame) |
| |
| |
| forward_speed = 0.00 |
|
|
| pose = np.eye(4, dtype=np.float32) |
| |
| pose[0, 0] = cos_yaw |
| pose[0, 2] = sin_yaw |
| pose[2, 0] = -sin_yaw |
| pose[2, 2] = cos_yaw |
| pose[2, 3] = -forward_speed |
| else: |
| |
| pose = np.eye(4, dtype=np.float32) |
| |
| relative_pose = pose[:3, :] |
| relative_poses.append(torch.as_tensor(relative_pose)) |
| |
| elif direction=="right": |
| print("--------------- RIGHT TURNING MODE ---------------") |
| relative_poses = [] |
| for i in range(max_needed_frames): |
| if i < CONDITION_FRAMES: |
| |
| pose = np.eye(4, dtype=np.float32) |
| elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
| |
| yaw_per_frame = -0.03 |
|
|
| |
| cos_yaw = np.cos(yaw_per_frame) |
| sin_yaw = np.sin(yaw_per_frame) |
| |
| |
| forward_speed = 0.00 |
|
|
| pose = np.eye(4, dtype=np.float32) |
| |
| pose[0, 0] = cos_yaw |
| pose[0, 2] = sin_yaw |
| pose[2, 0] = -sin_yaw |
| pose[2, 2] = cos_yaw |
| pose[2, 3] = -forward_speed |
| else: |
| |
| pose = np.eye(4, dtype=np.float32) |
| |
| relative_pose = pose[:3, :] |
| relative_poses.append(torch.as_tensor(relative_pose)) |
| |
| elif direction=="forward_left": |
| print("--------------- FORWARD LEFT MODE ---------------") |
| relative_poses = [] |
| for i in range(max_needed_frames): |
| if i < CONDITION_FRAMES: |
| |
| pose = np.eye(4, dtype=np.float32) |
| elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
| |
| yaw_per_frame = 0.03 |
|
|
| |
| cos_yaw = np.cos(yaw_per_frame) |
| sin_yaw = np.sin(yaw_per_frame) |
| |
| |
| forward_speed = 0.03 |
|
|
| pose = np.eye(4, dtype=np.float32) |
| |
| pose[0, 0] = cos_yaw |
| pose[0, 2] = sin_yaw |
| pose[2, 0] = -sin_yaw |
| pose[2, 2] = cos_yaw |
| pose[2, 3] = -forward_speed |
| |
| else: |
| |
| pose = np.eye(4, dtype=np.float32) |
| |
| relative_pose = pose[:3, :] |
| relative_poses.append(torch.as_tensor(relative_pose)) |
| |
| elif direction=="forward_right": |
| print("--------------- FORWARD RIGHT MODE ---------------") |
| relative_poses = [] |
| for i in range(max_needed_frames): |
| if i < CONDITION_FRAMES: |
| |
| pose = np.eye(4, dtype=np.float32) |
| elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
| |
| yaw_per_frame = -0.03 |
|
|
| |
| cos_yaw = np.cos(yaw_per_frame) |
| sin_yaw = np.sin(yaw_per_frame) |
| |
| |
| forward_speed = 0.03 |
|
|
| pose = np.eye(4, dtype=np.float32) |
| |
| pose[0, 0] = cos_yaw |
| pose[0, 2] = sin_yaw |
| pose[2, 0] = -sin_yaw |
| pose[2, 2] = cos_yaw |
| pose[2, 3] = -forward_speed |
| |
| else: |
| |
| pose = np.eye(4, dtype=np.float32) |
| |
| relative_pose = pose[:3, :] |
| relative_poses.append(torch.as_tensor(relative_pose)) |
| |
| elif direction=="s_curve": |
| print("--------------- S CURVE MODE ---------------") |
| relative_poses = [] |
| for i in range(max_needed_frames): |
| if i < CONDITION_FRAMES: |
| |
| pose = np.eye(4, dtype=np.float32) |
| elif i < CONDITION_FRAMES+STAGE_1: |
| |
| yaw_per_frame = 0.03 |
|
|
| |
| cos_yaw = np.cos(yaw_per_frame) |
| sin_yaw = np.sin(yaw_per_frame) |
| |
| |
| forward_speed = 0.03 |
|
|
| pose = np.eye(4, dtype=np.float32) |
| |
| pose[0, 0] = cos_yaw |
| pose[0, 2] = sin_yaw |
| pose[2, 0] = -sin_yaw |
| pose[2, 2] = cos_yaw |
| pose[2, 3] = -forward_speed |
| |
| elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
| |
| yaw_per_frame = -0.03 |
|
|
| |
| cos_yaw = np.cos(yaw_per_frame) |
| sin_yaw = np.sin(yaw_per_frame) |
| |
| |
| forward_speed = 0.03 |
| |
| if i < CONDITION_FRAMES+STAGE_1+STAGE_2//3: |
| radius_shift = -0.01 |
| else: |
| radius_shift = 0.00 |
|
|
| pose = np.eye(4, dtype=np.float32) |
| |
| pose[0, 0] = cos_yaw |
| pose[0, 2] = sin_yaw |
| pose[2, 0] = -sin_yaw |
| pose[2, 2] = cos_yaw |
| pose[2, 3] = -forward_speed |
| pose[0, 3] = radius_shift |
| |
| else: |
| |
| pose = np.eye(4, dtype=np.float32) |
| |
| relative_pose = pose[:3, :] |
| relative_poses.append(torch.as_tensor(relative_pose)) |
| |
| elif direction=="left_right": |
| print("--------------- LEFT RIGHT MODE ---------------") |
| relative_poses = [] |
| for i in range(max_needed_frames): |
| if i < CONDITION_FRAMES: |
| |
| pose = np.eye(4, dtype=np.float32) |
| elif i < CONDITION_FRAMES+STAGE_1: |
| |
| yaw_per_frame = 0.03 |
|
|
| |
| cos_yaw = np.cos(yaw_per_frame) |
| sin_yaw = np.sin(yaw_per_frame) |
| |
| |
| forward_speed = 0.00 |
|
|
| pose = np.eye(4, dtype=np.float32) |
| |
| pose[0, 0] = cos_yaw |
| pose[0, 2] = sin_yaw |
| pose[2, 0] = -sin_yaw |
| pose[2, 2] = cos_yaw |
| pose[2, 3] = -forward_speed |
| |
| elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
| |
| yaw_per_frame = -0.03 |
|
|
| |
| cos_yaw = np.cos(yaw_per_frame) |
| sin_yaw = np.sin(yaw_per_frame) |
| |
| |
| forward_speed = 0.00 |
|
|
| pose = np.eye(4, dtype=np.float32) |
| |
| pose[0, 0] = cos_yaw |
| pose[0, 2] = sin_yaw |
| pose[2, 0] = -sin_yaw |
| pose[2, 2] = cos_yaw |
| pose[2, 3] = -forward_speed |
| |
| else: |
| |
| pose = np.eye(4, dtype=np.float32) |
| |
| relative_pose = pose[:3, :] |
| relative_poses.append(torch.as_tensor(relative_pose)) |
| |
| else: |
| raise ValueError(f"Not Defined Direction: {direction}") |
| |
| pose_embedding = torch.stack(relative_poses, dim=0) |
| pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') |
| |
| |
| mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
| condition_end = min(start_frame + initial_condition_frames + 1, max_needed_frames) |
| mask[start_frame:condition_end] = 1.0 |
| |
| camera_embedding = torch.cat([pose_embedding, mask], dim=1) |
| print(f"🔧 Sekai synthetic camera embedding shape: {camera_embedding.shape}") |
| return camera_embedding.to(torch.bfloat16) |
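|
| # Layout note: each row of the returned Sekai embedding holds 13 values, a flattened 3x4 |
| # relative pose (9 rotation + 3 translation entries) followed by a 0/1 condition mask, |
| # matching the 13-dim "sekai" input dimension used in add_moe_components. |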
|
|
|
|
| def generate_openx_camera_embeddings_sliding( |
| encoded_data, start_frame, initial_condition_frames, new_frames, use_real_poses): |
| """Generate camera embeddings for OpenX dataset - sliding window version""" |
| time_compression_ratio = 4 |
| |
| |
| framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames |
| |
| if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']: |
| print("🔧 Using OpenX real camera data") |
| cam_extrinsic = encoded_data['cam_emb']['extrinsic'] |
| |
| |
| max_needed_frames = max( |
| start_frame + initial_condition_frames + new_frames, |
| framepack_needed_frames, |
| 30 |
| ) |
| |
| print(f"🔧 Calculating OpenX camera sequence length:") |
| print(f" - Basic requirement: {start_frame + initial_condition_frames + new_frames}") |
| print(f" - FramePack requirement: {framepack_needed_frames}") |
| print(f" - Final generation: {max_needed_frames}") |
| |
| relative_poses = [] |
| for i in range(max_needed_frames): |
| |
| frame_idx = i * time_compression_ratio |
| next_frame_idx = frame_idx + time_compression_ratio |
| |
| if next_frame_idx < len(cam_extrinsic): |
| cam_prev = cam_extrinsic[frame_idx] |
| cam_next = cam_extrinsic[next_frame_idx] |
| relative_pose = compute_relative_pose(cam_prev, cam_next) |
| relative_poses.append(torch.as_tensor(relative_pose[:3, :])) |
| else: |
| |
| print(f"⚠️ Frame {frame_idx} exceeds OpenX camera data range, using zero motion") |
| relative_poses.append(torch.zeros(3, 4)) |
| |
| pose_embedding = torch.stack(relative_poses, dim=0) |
| pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') |
| |
| |
| mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
| |
| condition_end = min(start_frame + initial_condition_frames, max_needed_frames) |
| mask[start_frame:condition_end] = 1.0 |
| |
| camera_embedding = torch.cat([pose_embedding, mask], dim=1) |
| print(f"🔧 OpenX real camera embedding shape: {camera_embedding.shape}") |
| return camera_embedding.to(torch.bfloat16) |
| |
| else: |
| print("🔧 Using OpenX synthetic camera data") |
| |
| max_needed_frames = max( |
| start_frame + initial_condition_frames + new_frames, |
| framepack_needed_frames, |
| 30 |
| ) |
| |
| print(f"🔧 Generating OpenX synthetic camera frames: {max_needed_frames}") |
| relative_poses = [] |
| for i in range(max_needed_frames): |
| |
| |
| roll_per_frame = 0.02 |
| pitch_per_frame = 0.01 |
| yaw_per_frame = 0.015 |
| forward_speed = 0.003 |
| |
| pose = np.eye(4, dtype=np.float32) |
| |
| |
| |
| cos_roll = np.cos(roll_per_frame) |
| sin_roll = np.sin(roll_per_frame) |
| |
| cos_pitch = np.cos(pitch_per_frame) |
| sin_pitch = np.sin(pitch_per_frame) |
| |
| cos_yaw = np.cos(yaw_per_frame) |
| sin_yaw = np.sin(yaw_per_frame) |
| |
| |
| pose[0, 0] = cos_yaw * cos_pitch |
| pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll |
| pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll |
| pose[1, 0] = sin_yaw * cos_pitch |
| pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll |
| pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll |
| pose[2, 0] = -sin_pitch |
| pose[2, 1] = cos_pitch * sin_roll |
| pose[2, 2] = cos_pitch * cos_roll |
| |
| |
| pose[0, 3] = forward_speed * 0.5 |
| pose[1, 3] = forward_speed * 0.3 |
| pose[2, 3] = -forward_speed |
| |
| relative_pose = pose[:3, :] |
| relative_poses.append(torch.as_tensor(relative_pose)) |
| |
| pose_embedding = torch.stack(relative_poses, dim=0) |
| pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') |
| |
| |
| mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
| condition_end = min(start_frame + initial_condition_frames, max_needed_frames) |
| mask[start_frame:condition_end] = 1.0 |
| |
| camera_embedding = torch.cat([pose_embedding, mask], dim=1) |
| print(f"🔧 OpenX synthetic camera embedding shape: {camera_embedding.shape}") |
| return camera_embedding.to(torch.bfloat16) |
|
|
|
|
| def generate_nuscenes_camera_embeddings_sliding( |
| scene_info, start_frame, initial_condition_frames, new_frames): |
| """ |
| Generate camera embeddings for NuScenes dataset - sliding window version |
| |
| Pose construction follows the same conventions as train_moe.py. |
| """ |
| time_compression_ratio = 4 |
| |
| |
| framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames |
| |
| if scene_info is not None and 'keyframe_poses' in scene_info: |
| print("🔧 Using NuScenes real pose data") |
| keyframe_poses = scene_info['keyframe_poses'] |
| |
| if len(keyframe_poses) == 0: |
| print("⚠️ NuScenes keyframe_poses is empty, using zero pose") |
| max_needed_frames = max(framepack_needed_frames, 30) |
| |
| pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32) |
| |
| mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
| condition_end = min(start_frame + initial_condition_frames, max_needed_frames) |
| mask[start_frame:condition_end] = 1.0 |
| |
| camera_embedding = torch.cat([pose_sequence, mask], dim=1) |
| print(f"🔧 NuScenes zero pose embedding shape: {camera_embedding.shape}") |
| return camera_embedding.to(torch.bfloat16) |
| |
| |
| reference_pose = keyframe_poses[0] |
| |
| max_needed_frames = max(framepack_needed_frames, 30) |
| |
| pose_vecs = [] |
| for i in range(max_needed_frames): |
| if i < len(keyframe_poses): |
| current_pose = keyframe_poses[i] |
| |
| |
| translation = torch.tensor( |
| np.array(current_pose['translation']) - np.array(reference_pose['translation']), |
| dtype=torch.float32 |
| ) |
| |
| |
| rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32) |
| |
| pose_vec = torch.cat([translation, rotation], dim=0) |
| else: |
| |
| pose_vec = torch.cat([ |
| torch.zeros(3, dtype=torch.float32), |
| torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) |
| ], dim=0) |
| |
| pose_vecs.append(pose_vec) |
| |
| pose_sequence = torch.stack(pose_vecs, dim=0) |
| |
| |
| mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
| condition_end = min(start_frame + initial_condition_frames, max_needed_frames) |
| mask[start_frame:condition_end] = 1.0 |
| |
| camera_embedding = torch.cat([pose_sequence, mask], dim=1) |
| print(f"🔧 NuScenes real pose embedding shape: {camera_embedding.shape}") |
| return camera_embedding.to(torch.bfloat16) |
| |
| else: |
| print("🔧 Using NuScenes synthetic pose data") |
| max_needed_frames = max(framepack_needed_frames, 30) |
| |
| |
| pose_vecs = [] |
| for i in range(max_needed_frames): |
| |
| angle = i * 0.04 |
| radius = 15.0 |
| |
| |
| x = radius * np.sin(angle) |
| y = 0.0 |
| z = radius * (1 - np.cos(angle)) |
| |
| translation = torch.tensor([x, y, z], dtype=torch.float32) |
| |
| |
| yaw = angle + np.pi/2 |
| |
| rotation = torch.tensor([ |
| np.cos(yaw/2), |
| 0.0, |
| 0.0, |
| np.sin(yaw/2) |
| ], dtype=torch.float32) |
| |
| pose_vec = torch.cat([translation, rotation], dim=0) |
| pose_vecs.append(pose_vec) |
| |
| pose_sequence = torch.stack(pose_vecs, dim=0) |
| |
| |
| mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
| condition_end = min(start_frame + initial_condition_frames, max_needed_frames) |
| mask[start_frame:condition_end] = 1.0 |
| |
| camera_embedding = torch.cat([pose_sequence, mask], dim=1) |
| print(f"🔧 NuScenes synthetic left turn pose embedding shape: {camera_embedding.shape}") |
| return camera_embedding.to(torch.bfloat16) |
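|
| # Layout note: each NuScenes row holds 8 values, a 7-dim pose (xyz translation relative to |
| # the first keyframe plus a w-first quaternion) followed by the 0/1 condition mask, matching |
| # the 8-dim "nuscenes" input dimension used in add_moe_components. |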
|
|
| def prepare_framepack_sliding_window_with_camera_moe( |
| history_latents, |
| target_frames_to_generate, |
| camera_embedding_full, |
| start_frame, |
| modality_type, |
| max_history_frames=49): |
| """FramePack sliding window mechanism - MoE version""" |
| |
| C, T, H, W = history_latents.shape |
| |
| |
| |
| total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate |
| indices = torch.arange(0, total_indices_length) |
| split_sizes = [1, 16, 2, 1, target_frames_to_generate] |
| clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \ |
| indices.split(split_sizes, dim=0) |
| clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0) |
| |
| |
| if camera_embedding_full.shape[0] < total_indices_length: |
| print(f"⚠️ camera_embedding length insufficient, performing zero padding: current length {camera_embedding_full.shape[0]}, required length {total_indices_length}") |
| shortage = total_indices_length - camera_embedding_full.shape[0] |
| padding = torch.zeros(shortage, camera_embedding_full.shape[1], |
| dtype=camera_embedding_full.dtype, device=camera_embedding_full.device) |
| camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0) |
| |
| |
| combined_camera = torch.zeros( |
| total_indices_length, |
| camera_embedding_full.shape[1], |
| dtype=camera_embedding_full.dtype, |
| device=camera_embedding_full.device) |
| |
| |
| history_slice = camera_embedding_full[max(T - 19, 0):T, :].clone() |
| combined_camera[19 - history_slice.shape[0]:19, :] = history_slice |
| |
| |
| target_slice = camera_embedding_full[T:T + target_frames_to_generate, :].clone() |
| combined_camera[19:19 + target_slice.shape[0], :] = target_slice |
| |
| |
| combined_camera[:, -1] = 0.0 |
| |
| |
| if T > 0: |
| available_frames = min(T, 19) |
| start_pos = 19 - available_frames |
| combined_camera[start_pos:19, -1] = 1.0 |
| |
| print(f"🔧 MoE Camera mask update:") |
| print(f" - History frames: {T}") |
| print(f" - Valid condition frames: {available_frames if T > 0 else 0}") |
| print(f" - Modality type: {modality_type}") |
| |
| |
| clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device) |
| |
| if T > 0: |
| available_frames = min(T, 19) |
| start_pos = 19 - available_frames |
| clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :] |
| |
| clean_latents_4x = clean_latents_combined[:, 0:16, :, :] |
| clean_latents_2x = clean_latents_combined[:, 16:18, :, :] |
| clean_latents_1x = clean_latents_combined[:, 18:19, :, :] |
| |
| if T > 0: |
| start_latent = history_latents[:, 0:1, :, :] |
| else: |
| start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device) |
| |
| clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1) |
| |
| return { |
| 'latent_indices': latent_indices, |
| 'clean_latents': clean_latents, |
| 'clean_latents_2x': clean_latents_2x, |
| 'clean_latents_4x': clean_latents_4x, |
| 'clean_latent_indices': clean_latent_indices, |
| 'clean_latent_2x_indices': clean_latent_2x_indices, |
| 'clean_latent_4x_indices': clean_latent_4x_indices, |
| 'camera_embedding': combined_camera, |
| 'modality_type': modality_type, |
| 'current_length': T, |
| 'next_length': T + target_frames_to_generate |
| } |
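|
| # Index layout sketch for one window (e.g. target_frames_to_generate = 4): |
| #   position:  0      1..16      17..18     19        20..23 |
| #   role:      start  clean 4x   clean 2x   clean 1x  frames being denoised |
| # clean_latent_indices concatenates position 0 with position 19, mirroring how clean_latents |
| # stacks the first history frame with the most recent one. |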
|
|
| def overlay_controls(frame_img, pose_vec, icons): |
| """ |
| Overlay control icons (WASD and arrow keys) on the frame based on the camera pose. |
| pose_vec: 12 elements (flattened 3x4 matrix) + mask |
| """ |
| if pose_vec is None or np.all(pose_vec[:12] == 0): |
| return frame_img |
| |
| |
| |
| tx = pose_vec[3] |
| |
| tz = pose_vec[11] |
| |
| |
| |
| r00 = pose_vec[0] |
| r02 = pose_vec[2] |
| yaw = np.arctan2(r02, r00) |
| |
| |
| r12 = pose_vec[6] |
| r22 = pose_vec[10] |
| pitch = np.arctan2(-r12, r22) |
| |
| |
| TRANS_THRESH = 0.01 |
| ROT_THRESH = 0.005 |
| |
| |
| |
| |
| is_forward = tz < -TRANS_THRESH |
| is_backward = tz > TRANS_THRESH |
| is_left = tx < -TRANS_THRESH |
| is_right = tx > TRANS_THRESH |
| |
| |
| |
| is_turn_left = yaw > ROT_THRESH |
| is_turn_right = yaw < -ROT_THRESH |
| |
| |
| is_turn_up = pitch < -ROT_THRESH |
| is_turn_down = pitch > ROT_THRESH |
| |
| W, H = frame_img.size |
| spacing = 60 |
| |
| def paste_icon(name_active, name_inactive, is_active, x, y): |
| name = name_active if is_active else name_inactive |
| if name in icons: |
| icon = icons[name] |
| |
| frame_img.paste(icon, (int(x), int(y)), icon) |
| |
| |
| base_x_right = 100 |
| base_y = H - 100 |
| |
| |
| paste_icon('move_forward.png', 'not_move_forward.png', is_forward, base_x_right, base_y - spacing) |
| |
| paste_icon('move_left.png', 'not_move_left.png', is_left, base_x_right - spacing, base_y) |
| |
| paste_icon('move_backward.png', 'not_move_backward.png', is_backward, base_x_right, base_y) |
| |
| paste_icon('move_right.png', 'not_move_right.png', is_right, base_x_right + spacing, base_y) |
| |
| |
| base_x_left = W - 150 |
| |
| |
| paste_icon('turn_up.png', 'not_turn_up.png', is_turn_up, base_x_left, base_y - spacing) |
| |
| paste_icon('turn_left.png', 'not_turn_left.png', is_turn_left, base_x_left - spacing, base_y) |
| |
| paste_icon('turn_down.png', 'not_turn_down.png', is_turn_down, base_x_left, base_y) |
| |
| paste_icon('turn_right.png', 'not_turn_right.png', is_turn_right, base_x_left + spacing, base_y) |
| |
| return frame_img |
|
|
|
|
| def inference_moe_framepack_sliding_window( |
| condition_pth_path=None, |
| condition_video=None, |
| condition_image=None, |
| dit_path=None, |
| wan_model_path=None, |
| output_path="../examples/output_videos/output_moe_framepack_sliding.mp4", |
| start_frame=0, |
| initial_condition_frames=8, |
| frames_per_generation=4, |
| total_frames_to_generate=32, |
| max_history_frames=49, |
| device="cuda", |
| prompt="A video of a scene shot using a pedestrian's front camera while walking", |
| modality_type="sekai", |
| use_real_poses=True, |
| scene_info_path=None, |
| |
| use_camera_cfg=True, |
| camera_guidance_scale=2.0, |
| text_guidance_scale=1.0, |
| |
| moe_num_experts=4, |
| moe_top_k=2, |
| moe_hidden_dim=None, |
| direction="left", |
| use_gt_prompt=True, |
| add_icons=False |
| ): |
| """ |
| MoE FramePack sliding window video generation - multi-modal support |
| """ |
| |
| dir_path = os.path.dirname(output_path) |
| os.makedirs(dir_path, exist_ok=True) |
| |
| print(f"🔧 Starting MoE FramePack sliding window generation...") |
| print(f" Modality type: {modality_type}") |
| print(f" Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}") |
| print(f" Text guidance scale: {text_guidance_scale}") |
| print(f" MoE config: experts={moe_num_experts}, top_k={moe_top_k}") |
| |
| |
| replace_dit_model_in_manager() |
| |
| model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") |
| model_manager.load_models([ |
| os.path.join(wan_model_path, "diffusion_pytorch_model.safetensors"), |
| os.path.join(wan_model_path, "models_t5_umt5-xxl-enc-bf16.pth"), |
| os.path.join(wan_model_path, "Wan2.1_VAE.pth"), |
| ]) |
| pipe = WanVideoAstraPipeline.from_model_manager(model_manager, device="cuda") |
|
|
| |
| dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0] |
| for block in pipe.dit.blocks: |
| block.cam_encoder = nn.Linear(13, dim) |
| block.projector = nn.Linear(dim, dim) |
| block.cam_encoder.weight.data.zero_() |
| block.cam_encoder.bias.data.zero_() |
| block.projector.weight = nn.Parameter(torch.eye(dim)) |
| block.projector.bias = nn.Parameter(torch.zeros(dim)) |
| |
| |
| add_framepack_components(pipe.dit) |
| |
| |
| moe_config = { |
| "num_experts": moe_num_experts, |
| "top_k": moe_top_k, |
| "hidden_dim": moe_hidden_dim or dim * 2, |
| "sekai_input_dim": 13, |
| "nuscenes_input_dim": 8, |
| "openx_input_dim": 13 |
| } |
| add_moe_components(pipe.dit, moe_config) |
| |
| |
| dit_state_dict = torch.load(dit_path, map_location="cpu") |
| pipe.dit.load_state_dict(dit_state_dict, strict=False) |
| pipe = pipe.to(device) |
| model_dtype = next(pipe.dit.parameters()).dtype |
| |
| if hasattr(pipe.dit, 'clean_x_embedder'): |
| pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype) |
| |
| |
| pipe.scheduler.set_timesteps(50) |
| |
| |
| print("Loading initial condition frames...") |
| initial_latents, encoded_data = load_or_encode_condition( |
| condition_pth_path, |
| condition_video, |
| condition_image, |
| start_frame, |
| initial_condition_frames, |
| device, |
| pipe, |
| ) |
| |
| |
| target_height, target_width = 60, 104 |
| C, T, H, W = initial_latents.shape |
| |
| if H > target_height or W > target_width: |
| h_start = (H - target_height) // 2 |
| w_start = (W - target_width) // 2 |
| initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width] |
| H, W = target_height, target_width |
| |
| history_latents = initial_latents.to(device, dtype=model_dtype) |
|
|
| print(f"Initial history_latents shape: {history_latents.shape}") |
| |
| |
| if use_gt_prompt and 'prompt_emb' in encoded_data: |
| print("✅ Using pre-encoded GT prompt embedding") |
| prompt_emb_pos = encoded_data['prompt_emb'] |
| |
| if 'context' in prompt_emb_pos: |
| prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype) |
| if 'context_mask' in prompt_emb_pos: |
| prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype) |
| |
| |
| if text_guidance_scale > 1.0: |
| prompt_emb_neg = pipe.encode_prompt("") |
| print(f"Using Text CFG with GT prompt, guidance scale: {text_guidance_scale}") |
| else: |
| prompt_emb_neg = None |
| print("Not using Text CFG") |
| |
| |
| if 'prompt' in encoded_data['prompt_emb']: |
| gt_prompt_text = encoded_data['prompt_emb']['prompt'] |
| print(f"📝 GT Prompt text: {gt_prompt_text}") |
| else: |
| |
| print(f"🔄 Re-encoding prompt: {prompt}") |
| if text_guidance_scale > 1.0: |
| prompt_emb_pos = pipe.encode_prompt(prompt) |
| prompt_emb_neg = pipe.encode_prompt("") |
| print(f"Using Text CFG, guidance scale: {text_guidance_scale}") |
| else: |
| prompt_emb_pos = pipe.encode_prompt(prompt) |
| prompt_emb_neg = None |
| print("Not using Text CFG") |
| |
| |
| scene_info = None |
| if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path): |
| with open(scene_info_path, 'r') as f: |
| scene_info = json.load(f) |
| print(f"Loading NuScenes scene information: {scene_info_path}") |
| |
| |
| if modality_type == "sekai": |
| camera_embedding_full = generate_sekai_camera_embeddings_sliding( |
| encoded_data.get('cam_emb', None), |
| start_frame, |
| initial_condition_frames, |
| total_frames_to_generate, |
| 0, |
| use_real_poses=use_real_poses, |
| direction=direction |
| ).to(device, dtype=model_dtype) |
| elif modality_type == "nuscenes": |
| camera_embedding_full = generate_nuscenes_camera_embeddings_sliding( |
| scene_info, |
| start_frame, |
| initial_condition_frames, |
| total_frames_to_generate |
| ).to(device, dtype=model_dtype) |
| elif modality_type == "openx": |
| camera_embedding_full = generate_openx_camera_embeddings_sliding( |
| encoded_data, |
| start_frame, |
| initial_condition_frames, |
| total_frames_to_generate, |
| use_real_poses=use_real_poses |
| ).to(device, dtype=model_dtype) |
| else: |
| raise ValueError(f"Unsupported modality type: {modality_type}") |
| |
| print(f"Complete camera sequence shape: {camera_embedding_full.shape}") |
| |
| |
| if use_camera_cfg: |
| camera_embedding_uncond = torch.zeros_like(camera_embedding_full) |
| print(f"Creating unconditional camera embedding for CFG") |
| |
| |
| total_generated = 0 |
| all_generated_frames = [] |
| |
| while total_generated < total_frames_to_generate: |
| current_generation = min(frames_per_generation, total_frames_to_generate - total_generated) |
| print(f"\nGeneration step {total_generated // frames_per_generation + 1}") |
| print(f"Current history length: {history_latents.shape[1]}, generating: {current_generation}") |
| |
| |
| framepack_data = prepare_framepack_sliding_window_with_camera_moe( |
| history_latents, |
| current_generation, |
| camera_embedding_full, |
| start_frame, |
| modality_type, |
| max_history_frames |
| ) |
| |
| |
| clean_latents = framepack_data['clean_latents'].unsqueeze(0) |
| clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0) |
| clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0) |
| camera_embedding = framepack_data['camera_embedding'].unsqueeze(0) |
| |
| |
| modality_inputs = {modality_type: camera_embedding} |
| |
| |
| if use_camera_cfg: |
| camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0) |
| modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch} |
| |
| |
| latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu() |
| clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu() |
| clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu() |
| clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu() |
| |
| |
| new_latents = torch.randn( |
| 1, C, current_generation, H, W, |
| device=device, dtype=model_dtype |
| ) |
| |
| extra_input = pipe.prepare_extra_input(new_latents) |
| |
| print(f"Camera embedding shape: {camera_embedding.shape}") |
| print(f"Camera mask distribution - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}") |
| |
| |
| timesteps = pipe.scheduler.timesteps |
| |
| for i, timestep in enumerate(timesteps): |
| if i % 10 == 0: |
| print(f" Denoising step {i+1}/{len(timesteps)}") |
| |
| timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype) |
| |
| with torch.no_grad(): |
| |
| if use_camera_cfg and camera_guidance_scale > 1.0: |
| |
| noise_pred_cond, moe_loss = pipe.dit( |
| new_latents, |
| timestep=timestep_tensor, |
| cam_emb=camera_embedding, |
| modality_inputs=modality_inputs, |
| latent_indices=latent_indices, |
| clean_latents=clean_latents, |
| clean_latent_indices=clean_latent_indices, |
| clean_latents_2x=clean_latents_2x, |
| clean_latent_2x_indices=clean_latent_2x_indices, |
| clean_latents_4x=clean_latents_4x, |
| clean_latent_4x_indices=clean_latent_4x_indices, |
| **prompt_emb_pos, |
| **extra_input |
| ) |
| |
| |
| noise_pred_uncond, moe_loss = pipe.dit( |
| new_latents, |
| timestep=timestep_tensor, |
| cam_emb=camera_embedding_uncond_batch, |
| modality_inputs=modality_inputs_uncond, |
| latent_indices=latent_indices, |
| clean_latents=clean_latents, |
| clean_latent_indices=clean_latent_indices, |
| clean_latents_2x=clean_latents_2x, |
| clean_latent_2x_indices=clean_latent_2x_indices, |
| clean_latents_4x=clean_latents_4x, |
| clean_latent_4x_indices=clean_latent_4x_indices, |
| **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos), |
| **extra_input |
| ) |
| |
| |
| noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond) |
| |
| |
| if text_guidance_scale > 1.0 and prompt_emb_neg: |
| noise_pred_text_uncond, moe_loss = pipe.dit( |
| new_latents, |
| timestep=timestep_tensor, |
| cam_emb=camera_embedding, |
| modality_inputs=modality_inputs, |
| latent_indices=latent_indices, |
| clean_latents=clean_latents, |
| clean_latent_indices=clean_latent_indices, |
| clean_latents_2x=clean_latents_2x, |
| clean_latent_2x_indices=clean_latent_2x_indices, |
| clean_latents_4x=clean_latents_4x, |
| clean_latent_4x_indices=clean_latent_4x_indices, |
| **prompt_emb_neg, |
| **extra_input |
| ) |
| |
| |
| noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond) |
| |
| elif text_guidance_scale > 1.0 and prompt_emb_neg: |
| |
| noise_pred_cond, moe_loss = pipe.dit( |
| new_latents, |
| timestep=timestep_tensor, |
| cam_emb=camera_embedding, |
| modality_inputs=modality_inputs, |
| latent_indices=latent_indices, |
| clean_latents=clean_latents, |
| clean_latent_indices=clean_latent_indices, |
| clean_latents_2x=clean_latents_2x, |
| clean_latent_2x_indices=clean_latent_2x_indices, |
| clean_latents_4x=clean_latents_4x, |
| clean_latent_4x_indices=clean_latent_4x_indices, |
| **prompt_emb_pos, |
| **extra_input |
| ) |
| |
| noise_pred_uncond, moe_loss = pipe.dit( |
| new_latents, |
| timestep=timestep_tensor, |
| cam_emb=camera_embedding, |
| modality_inputs=modality_inputs, |
| latent_indices=latent_indices, |
| clean_latents=clean_latents, |
| clean_latent_indices=clean_latent_indices, |
| clean_latents_2x=clean_latents_2x, |
| clean_latent_2x_indices=clean_latent_2x_indices, |
| clean_latents_4x=clean_latents_4x, |
| clean_latent_4x_indices=clean_latent_4x_indices, |
| **prompt_emb_neg, |
| **extra_input |
| ) |
| |
| noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond) |
| |
| else: |
| |
| noise_pred, moe_loss = pipe.dit( |
| new_latents, |
| timestep=timestep_tensor, |
| cam_emb=camera_embedding, |
| modality_inputs=modality_inputs, |
| latent_indices=latent_indices, |
| clean_latents=clean_latents, |
| clean_latent_indices=clean_latent_indices, |
| clean_latents_2x=clean_latents_2x, |
| clean_latent_2x_indices=clean_latent_2x_indices, |
| clean_latents_4x=clean_latents_4x, |
| clean_latent_4x_indices=clean_latent_4x_indices, |
| **prompt_emb_pos, |
| **extra_input |
| ) |
| |
| new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents) |
| |
| |
| new_latents_squeezed = new_latents.squeeze(0) |
| history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1) |
| |
| |
| if history_latents.shape[1] > max_history_frames: |
| first_frame = history_latents[:, 0:1, :, :] |
| recent_frames = history_latents[:, -(max_history_frames-1):, :, :] |
| history_latents = torch.cat([first_frame, recent_frames], dim=1) |
| print(f"⚠️ History window full, keeping first frame + latest {max_history_frames-1} frames") |
| |
| print(f"History_latents shape after update: {history_latents.shape}") |
| |
| all_generated_frames.append(new_latents_squeezed) |
| total_generated += current_generation |
| |
| print(f"✅ Generated {total_generated}/{total_frames_to_generate} frames") |
| |
| |
| print("\nDecoding generated video...") |
| |
| all_generated = torch.cat(all_generated_frames, dim=1) |
| final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0) |
| |
| print(f"Final video shape: {final_video.shape}") |
| |
| decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) |
| |
| print(f"Saving video to {output_path} ...") |
| |
| video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy() |
| video_np = (video_np * 0.5 + 0.5).clip(0, 1) |
| video_np = (video_np * 255).astype(np.uint8) |
|
|
| icons = {} |
| video_camera_poses = None |
| if add_icons: |
| |
| icons_dir = os.path.join(ROOT_DIR, 'icons') |
| icon_names = ['move_forward.png', 'not_move_forward.png', |
| 'move_backward.png', 'not_move_backward.png', |
| 'move_left.png', 'not_move_left.png', |
| 'move_right.png', 'not_move_right.png', |
| 'turn_up.png', 'not_turn_up.png', |
| 'turn_down.png', 'not_turn_down.png', |
| 'turn_left.png', 'not_turn_left.png', |
| 'turn_right.png', 'not_turn_right.png'] |
| for name in icon_names: |
| path = os.path.join(icons_dir, name) |
| if os.path.exists(path): |
| try: |
| icon = Image.open(path).convert("RGBA") |
| |
| icon = icon.resize((50, 50), Image.Resampling.LANCZOS) |
| icons[name] = icon |
| except Exception as e: |
| print(f"Error loading icon {name}: {e}") |
| else: |
| print(f"⚠️ Warning: Icon {name} not found at {path}") |
|
|
| |
| time_compression_ratio = 4 |
| camera_poses = camera_embedding_full.detach().float().cpu().numpy() |
| video_camera_poses = [x for x in camera_poses for _ in range(time_compression_ratio)] |
|
|
| with imageio.get_writer(output_path, fps=20) as writer: |
| for i, frame in enumerate(video_np): |
| |
| img = Image.fromarray(frame) |
| |
| if add_icons and video_camera_poses is not None and icons: |
| |
| pose_idx = start_frame + i |
| if pose_idx < len(video_camera_poses): |
| pose_vec = video_camera_poses[pose_idx] |
| img = overlay_controls(img, pose_vec, icons) |
| |
| writer.append_data(np.array(img)) |
|
|
| print(f"✅ MoE FramePack sliding window generation completed! Saved to: {output_path}") |
| print(f" Total generated {total_generated} frames (compressed), corresponding to original {total_generated * 4} frames") |
| print(f" Using modality: {modality_type}") |
| |
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="MoE FramePack sliding window video generation - supports multi-modal") |
| |
| |
| parser.add_argument("--condition_pth", |
| type=str, |
| default=None, |
| help="Path to pre-encoded condition pth file") |
| parser.add_argument("--condition_video", |
| type=str, |
| default=None, |
| help="Input video for novel view synthesis.") |
| parser.add_argument("--condition_image", |
| type=str, |
| default=None, |
| help="Input image for novel view synthesis.") |
| parser.add_argument("--start_frame", type=int, default=0) |
| parser.add_argument("--initial_condition_frames", type=int, default=1) |
| parser.add_argument("--frames_per_generation", type=int, default=8) |
| parser.add_argument("--total_frames_to_generate", type=int, default=24) |
| parser.add_argument("--max_history_frames", type=int, default=100) |
| parser.add_argument("--use_real_poses", default=False) |
| parser.add_argument("--dit_path", type=str, |
| default="../models/Astra/checkpoints/diffusion_pytorch_model.ckpt", |
| help="path to the pretrained DiT MoE model checkpoint") |
| parser.add_argument("--wan_model_path", |
| type=str, |
| default="../models/Wan-AI/Wan2.1-T2V-1.3B", |
| help="path to Wan2.1-T2V-1.3B") |
| parser.add_argument("--output_path", type=str, |
| default='../examples/output_videos/output_moe_framepack_sliding.mp4') |
| parser.add_argument("--prompt", |
| type=str, |
| default="", |
| help="text prompt for video generation") |
| parser.add_argument("--device", type=str, default="cuda") |
| parser.add_argument("--add_icons", action="store_true", default=False, |
| help="Overlay control icons on generated video") |
| |
| |
| parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], |
| default="sekai", help="Modality type: sekai, nuscenes, or openx") |
| parser.add_argument("--scene_info_path", type=str, default=None, |
| help="NuScenes scene info file path (for nuscenes modality only)") |
| |
| |
| parser.add_argument("--use_camera_cfg", default=False, |
| help="Use Camera CFG") |
| parser.add_argument("--camera_guidance_scale", type=float, default=2.0, |
| help="Camera guidance scale for CFG") |
| parser.add_argument("--text_guidance_scale", type=float, default=1.0, |
| help="Text guidance scale for CFG") |
| |
| |
| parser.add_argument("--moe_num_experts", type=int, default=3, help="Number of experts") |
| parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K experts") |
| parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE hidden dimension") |
| parser.add_argument("--direction", type=str, default="left", help="Direction of video trajectory") |
| parser.add_argument("--use_gt_prompt", action="store_true", default=False, |
| help="Use ground truth prompt embedding from dataset") |
| |
| args = parser.parse_args() |
|
|
| print(f"MoE FramePack CFG generation settings:") |
| print(f"Modality type: {args.modality_type}") |
| print(f"Camera CFG: {args.use_camera_cfg}") |
| if args.use_camera_cfg: |
| print(f"Camera guidance scale: {args.camera_guidance_scale}") |
| print(f"Using GT Prompt: {args.use_gt_prompt}") |
| print(f"Text guidance scale: {args.text_guidance_scale}") |
| print(f"MoE config: experts={args.moe_num_experts}, top_k={args.moe_top_k}") |
| print(f"DiT{args.dit_path}") |
| |
| |
| if args.modality_type == "nuscenes" and not args.scene_info_path: |
| print("⚠️ Warning: Using NuScenes modality but scene_info_path not provided, will use synthetic pose data") |
| |
| if not args.use_gt_prompt and (args.prompt is None or args.prompt.strip() == ""): |
| print("⚠️ Warning: No prompt provided, will use empty string as prompt") |
| |
| if not any([args.condition_pth, args.condition_video, args.condition_image]): |
| raise ValueError("Need to provide condition_pth, condition_video, or condition_image as condition input") |
| |
| if args.condition_pth: |
| print(f"Using pre-encoded pth: {args.condition_pth}") |
| elif args.condition_video: |
| print(f"Using condition video for online encoding: {args.condition_video}") |
| elif args.condition_image: |
| print(f"Using condition image for online encoding: {args.condition_image} (repeat 10 frames)") |
| |
| inference_moe_framepack_sliding_window( |
| condition_pth_path=args.condition_pth, |
| condition_video=args.condition_video, |
| condition_image=args.condition_image, |
| dit_path=args.dit_path, |
| wan_model_path=args.wan_model_path, |
| output_path=args.output_path, |
| start_frame=args.start_frame, |
| initial_condition_frames=args.initial_condition_frames, |
| frames_per_generation=args.frames_per_generation, |
| total_frames_to_generate=args.total_frames_to_generate, |
| max_history_frames=args.max_history_frames, |
| device=args.device, |
| prompt=args.prompt, |
| modality_type=args.modality_type, |
| use_real_poses=args.use_real_poses, |
| scene_info_path=args.scene_info_path, |
| |
| use_camera_cfg=args.use_camera_cfg, |
| camera_guidance_scale=args.camera_guidance_scale, |
| text_guidance_scale=args.text_guidance_scale, |
| |
| moe_num_experts=args.moe_num_experts, |
| moe_top_k=args.moe_top_k, |
| moe_hidden_dim=args.moe_hidden_dim, |
| direction=args.direction, |
| use_gt_prompt=args.use_gt_prompt, |
| add_icons=args.add_icons |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
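|
| # Example invocation (script name and paths are placeholders; adjust to your checkout): |
| #   python inference_moe_framepack_sliding.py \ |
| #       --condition_image ../examples/condition.png \ |
| #       --dit_path ../models/Astra/checkpoints/diffusion_pytorch_model.ckpt \ |
| #       --wan_model_path ../models/Wan-AI/Wan2.1-T2V-1.3B \ |
| #       --modality_type sekai --direction left --use_camera_cfg --add_icons |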