| import numpy as np |
| import os |
| import torch |
| from einops import rearrange |
|
|
# Absolute path of the directory that contains this module file.
script_directory = os.path.split(os.path.abspath(__file__))[0]
|
|
class Camera:
    """Pinhole camera built from one CameraCtrl pose row.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    ``entry[1:5]`` supplies (fx, fy, cx, cy) and ``entry[7:]`` supplies the 12
    values of a row-major 3x4 world-to-camera matrix; the remaining positions of
    ``entry`` are not read here.

    Attributes:
        fx, fy, cx, cy: intrinsics exactly as given in the row.
        w2c_mat: 4x4 homogeneous world-to-camera matrix (float64).
        c2w_mat: inverse of ``w2c_mat``.
    """

    def __init__(self, entry):
        self.fx, self.fy, self.cx, self.cy = entry[1:5]
        # Promote the 3x4 world-to-camera block to a homogeneous 4x4 matrix.
        extrinsic = np.eye(4)
        extrinsic[:3, :] = np.array(entry[7:]).reshape(3, 4)
        self.w2c_mat = extrinsic
        self.c2w_mat = np.linalg.inv(extrinsic)
|
|
def custom_meshgrid(*args):
    """Return ``torch.meshgrid(*args)`` using matrix ("ij") indexing.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    "ij" is torch's historical default, so the grids are unchanged. Passing it
    explicitly avoids the warning that newer torch versions emit when
    ``indexing`` is omitted.
    """
    try:
        return torch.meshgrid(*args, indexing="ij")
    except TypeError:
        # Very old torch releases do not accept ``indexing`` and always use "ij".
        return torch.meshgrid(*args)
| |
|
|
def get_relative_pose(cam_params):
    """Express every camera pose relative to the first camera.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    The first camera becomes ``target_cam_c2w``, which is the identity because
    ``cam_to_origin`` is 0. Each later camera-to-world matrix is left-multiplied
    by ``target_cam_c2w @ w2c_0``.

    Args:
        cam_params: sequence of objects exposing ``w2c_mat`` and ``c2w_mat``
            (4x4 matrices when built by ``Camera``).

    Returns:
        float32 array stacking one relative pose per camera.
    """
    world_to_cam = [cam.w2c_mat for cam in cam_params]
    cam_to_world = [cam.c2w_mat for cam in cam_params]

    cam_to_origin = 0
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ])
    to_first_frame = target_cam_c2w @ world_to_cam[0]

    relative = [target_cam_c2w]
    for c2w in cam_to_world[1:]:
        relative.append(to_first_frame @ c2w)
    return np.array(relative, dtype=np.float32)
|
|
def ray_condition(K, c2w, H, W, device):
    """Compute per-pixel Plucker ray embeddings for a batch of camera trajectories.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py,
    with the cross-product axis made explicit.

    Args:
        K: intrinsics whose last axis holds (fx, fy, cx, cy) in pixels; the caller
            in this module passes shape (B, V, 4) -- TODO confirm for other callers.
        c2w: camera-to-world matrices of shape (B, V, 4, 4).
        H, W: output height and width in pixels.
        device: device on which the pixel grid is created; K and c2w must live
            on the same device.

    Returns:
        Tensor of shape (B, V, H, W, 6). Channels 0-2 hold ``o x d`` (ray origin
        crossed with direction) and channels 3-5 hold the unit direction ``d``.
    """
    B = K.shape[0]

    # Pixel-centre coordinates: j runs over rows, i over columns.
    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5

    fx, fy, cx, cy = K.chunk(4, dim=-1)

    # Back-project each pixel onto the z = 1 plane in camera space, then normalise.
    zs = torch.ones_like(i)
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)
    directions = torch.stack((xs, ys, zs), dim=-1)
    directions = directions / directions.norm(dim=-1, keepdim=True)

    # Rotate the directions into world space; every ray of a view starts at that
    # view's camera centre (translation column of c2w).
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)
    rays_o = c2w[..., :3, 3]
    rays_o = rays_o[:, :, None].expand_as(rays_d)

    # dim=-1 is essential: without it torch.cross uses the FIRST axis of size 3,
    # which is the frame axis whenever V == 3 (or B == 3, or H * W == 3).
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    return plucker.reshape(B, c2w.shape[1], H, W, 6)
|
|
def process_poses(poses, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
    """Turn CameraCtrl pose rows into a per-frame Plucker ray embedding.

    Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        poses: iterable of pose rows; every element is converted to float. Rows
            are parsed by ``Camera``: indices 1-4 are fx, fy, cx, cy as fractions
            of the image size (they are multiplied by width/height below), and
            indices 7+ are a 3x4 world-to-camera matrix.
        width, height: embedding size in pixels.
        original_pose_width, original_pose_height: resolution the poses refer to;
            used to correct the focal length when aspect ratios differ.
        device: torch device on which the embedding is computed.
        return_poses: if True, return the float-converted rows without building
            the embedding.

    Returns:
        The list of float rows when ``return_poses`` is True. Otherwise a float32
        tensor of shape (num_frames, height, width, 6).
    """
    cam_params = [[float(x) for x in pose] for pose in poses]
    if return_poses:
        return cam_params
    cam_params = [Camera(cam_param) for cam_param in cam_params]

    # When the pose aspect ratio differs from the sample's, rescale the normalized
    # focal length along the axis whose extent changes (presumably a centre crop).
    sample_wh_ratio = width / height
    pose_wh_ratio = original_pose_width / original_pose_height
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / width
    else:
        resized_ori_h = width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / height

    # Normalized intrinsics -> pixel units.
    intrinsic = np.asarray([[cam_param.fx * width,
                             cam_param.fy * height,
                             cam_param.cx * width,
                             cam_param.cy * height]
                            for cam_param in cam_params], dtype=np.float32)

    # Build K and the poses directly on `device`. ray_condition creates its pixel
    # grid there, so CPU-only inputs would fail for any other device.
    K = torch.as_tensor(intrinsic, device=device)[None]
    c2ws = torch.as_tensor(get_relative_pose(cam_params), device=device)[None]

    # ray_condition yields (1, frames, H, W, 6); dropping the batch axis gives the
    # (frames, H, W, 6) layout directly.
    return ray_condition(K, c2ws, height, width, device=device)[0]
|
|
class WanVideoFunCameraEmbeds:
    """Node that packs CameraCtrl poses into WanVideo "Fun" camera-control embeds."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "poses": ("CAMERACTRL_POSES", ),
                "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}),
                # Same upper bound as width.
                "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 8, "tooltip": "Height of the image to encode"}),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Strength of the camera motion"}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply camera motion"}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply camera motion"}),
            },
            # `process` accepts fun_ref_image; expose it so it can actually be supplied.
            "optional": {
                "fun_ref_image": ("LATENT", {"tooltip": "Optional reference latent; only index 0 along its third axis is used"}),
            },
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    def process(self, poses, width, height, strength, start_percent, end_percent, fun_ref_image=None):
        """Pack camera poses into latent-resolution control embeds.

        Args:
            poses: sequence of CameraCtrl pose rows, one per video frame; its
                length must be 4k + 1 (1, 5, 9, ...).
            width, height: output size in pixels.
            strength: scale applied to the control latents.
            start_percent, end_percent: copied unchanged into the returned embeds.
            fun_ref_image: optional dict with a "samples" tensor; only index 0
                along dim 2 is kept.

        Returns:
            One-element tuple holding the embeds dict.

        Raises:
            ValueError: if the number of poses is not of the form 4k + 1.
        """
        num_frames = len(poses)
        # Frame 0 is repeated 4x below and the sequence is then folded into groups
        # of 4, so num_frames + 3 must be divisible by 4. Fail early with a clear
        # message instead of an opaque `view` shape error.
        if num_frames < 1 or (num_frames - 1) % 4 != 0:
            raise ValueError(
                f"WanVideoFunCameraEmbeds expects 4k + 1 poses (1, 5, 9, ...), got {num_frames}"
            )

        # process_poses returns (frames, H, W, 6) -> reorder to (1, 6, frames, H, W).
        control_camera_video = process_poses(poses, width, height)
        control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0)
        print("control_camera_video.shape", control_camera_video.shape)

        # Repeat the first frame 4x, then move frames ahead of channels:
        # (1, frames + 3, 6, H, W).
        control_camera_latents = torch.concat(
            [
                torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
                control_camera_video[:, :, 1:],
            ],
            dim=2,
        ).transpose(1, 2)

        # Fold every 4 consecutive frames into the channel axis:
        # (b, f, c, h, w) -> (b, c * 4, f // 4, h, w).
        b, f, c, h, w = control_camera_latents.shape
        control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
        control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
        print("control_camera_latents.shape", control_camera_latents.shape)

        # Temporal / spatial downsampling factors -- presumably those of the video VAE; confirm.
        vae_stride = (4, 8, 8)
        target_shape = (16, (num_frames - 1) // vae_stride[0] + 1,
                        height // vae_stride[1],
                        width // vae_stride[2])

        embeds = {
            "target_shape": target_shape,
            "num_frames": num_frames,
            "control_embeds": {
                "control_camera_latents": control_camera_latents * strength,
                "control_camera_start_percent": start_percent,
                "control_camera_end_percent": end_percent,
                "fun_ref_image": fun_ref_image["samples"][:, :, 0] if fun_ref_image is not None else None,
            },
        }
        return (embeds,)
|
|
# Registration tables: node id -> implementing class, and node id -> display label
# (names follow the host framework's custom-node convention, presumably ComfyUI).
NODE_CLASS_MAPPINGS = dict(
    WanVideoFunCameraEmbeds=WanVideoFunCameraEmbeds,
)
NODE_DISPLAY_NAME_MAPPINGS = dict(
    WanVideoFunCameraEmbeds="WanVideo FunCamera Embeds",
)
|
|