| from typing import Any, Callable, Dict, List, Optional, Union |
| import PIL.Image |
| import torch |
| import math |
| import random |
| import numpy as np |
| import torch.nn.functional as F |
| from typing import Tuple |
| from PIL import Image |
|
|
| from vae import WanVAE |
| from vace.models.wan.modules.model_mm import VaceMMModel |
| from vace.models.wan.modules.model_tr import VaceWanModel |
|
|
| from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback |
| from diffusers.image_processor import PipelineImageInput |
| from diffusers.loaders import WanLoraLoaderMixin |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler |
| from diffusers.utils import logging |
| from diffusers.utils.torch_utils import randn_tensor |
| from diffusers.video_processor import VideoProcessor |
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline |
| from diffusers.utils import BaseOutput |
| from dataclasses import dataclass |
|
|
|
|
@dataclass
class RefacadePipelineOutput(BaseOutput):
    """Output container returned by ``RefacadePipeline.__call__``."""

    # Generated video frames (or raw latents when output_type == "latent").
    frames: torch.Tensor
    # Decoded texture-removed ("mesh") video used as conditioning.
    meshes: torch.Tensor
    # The processed reference image actually fed to the model
    # (HWC uint8 array after decoding, or the reference latent in latent mode).
    ref_img: torch.Tensor
|
|
|
|
# Module-level logger, following the diffusers convention.
logger = logging.get_logger(__name__)
|
|
|
|
| @torch.no_grad() |
| def _pad_to_multiple(x: torch.Tensor, multiple: int, mode: str = "reflect"): |
| H, W = x.shape[-2], x.shape[-1] |
| pad_h = (multiple - H % multiple) % multiple |
| pad_w = (multiple - W % multiple) % multiple |
| pad = (0, pad_w, 0, pad_h) |
| if pad_h or pad_w: |
| x = F.pad(x, pad, mode=mode) |
| return x, pad |
|
|
|
|
| @torch.no_grad() |
| def _unpad(x: torch.Tensor, pad): |
| l, r, t, b = pad |
| H, W = x.shape[-2], x.shape[-1] |
| return x[..., t:H - b if b > 0 else H, l:W - r if r > 0 else W] |
|
|
|
|
| @torch.no_grad() |
| def _resize(x: torch.Tensor, size: tuple, is_mask: bool): |
| mode = "nearest" if is_mask else "bilinear" |
| if is_mask: |
| return F.interpolate(x, size=size, mode=mode) |
| else: |
| return F.interpolate(x, size=size, mode=mode, align_corners=False) |
|
|
|
|
@torch.no_grad()
def _center_scale_foreground_to_canvas(
    x_f: torch.Tensor,
    m_f: torch.Tensor,
    target_hw: tuple,
    bg_value: float = 1.0,
):
    """Crop the mask's bounding box out of ``x_f`` (C, H, W), rescale it to
    fit inside ``target_hw`` with aspect preserved, and paste it centered on
    a ``bg_value``-filled canvas.

    Returns ``(image_canvas, mask_canvas)``; the mask canvas is re-binarized
    at 0.5 after resizing.  An empty mask yields the blank canvases.
    """
    channels, _, _ = x_f.shape
    out_h, out_w = target_hw
    dev = x_f.device
    rows, cols = (m_f > 0.5).nonzero(as_tuple=True)
    img_canvas = torch.full((channels, out_h, out_w), bg_value, dtype=x_f.dtype, device=dev)
    msk_canvas = torch.zeros((1, out_h, out_w), dtype=x_f.dtype, device=dev)
    if rows.numel() == 0:
        # No foreground pixels at all: nothing to place.
        return img_canvas, msk_canvas

    # Tight bounding box of the foreground.
    r0, r1 = rows.min().item(), rows.max().item()
    c0, c1 = cols.min().item(), cols.max().item()
    fg_img = x_f[:, r0:r1 + 1, c0:c1 + 1]
    fg_msk = m_f[r0:r1 + 1, c0:c1 + 1].unsqueeze(0)
    box_h, box_w = fg_msk.shape[-2], fg_msk.shape[-1]

    # Largest uniform scale that still fits the target canvas.
    scale = min(out_h / max(1, box_h), out_w / max(1, box_w))
    new_h = max(1, min(out_h, int(math.floor(box_h * scale))))
    new_w = max(1, min(out_w, int(math.floor(box_w * scale))))
    img_up = _resize(fg_img.unsqueeze(0), (new_h, new_w), is_mask=False).squeeze(0)
    msk_up = _resize(fg_msk.unsqueeze(0), (new_h, new_w), is_mask=True).squeeze(0)
    msk_up = (msk_up > 0.5).to(msk_up.dtype)

    # Paste centered.
    y_off = (out_h - new_h) // 2
    x_off = (out_w - new_w) // 2
    img_canvas[:, y_off:y_off + new_h, x_off:x_off + new_w] = img_up
    msk_canvas[:, y_off:y_off + new_h, x_off:x_off + new_w] = msk_up
    return img_canvas, msk_canvas
|
|
|
|
| @torch.no_grad() |
| def _sample_patch_size_from_hw( |
| H: int, |
| W: int, |
| ratio: float = 0.2, |
| min_px: int = 16, |
| max_px: Optional[int] = None, |
| ) -> int: |
| r = ratio |
| raw = r * min(H, W) |
| if max_px is None: |
| max_px = min(192, min(H, W)) |
| P = int(round(raw)) |
| P = max(min_px, min(P, max_px)) |
| P = int(P) |
| return P |
|
|
|
|
@torch.no_grad()
def _masked_patch_pack_to_center_rectangle(
    x_f: torch.Tensor,
    m_f: torch.Tensor,
    patch: int,
    fg_thresh: float = 0.8,
    bg_value: float = 1.0,
    min_patches: int = 4,
    flip_prob: float = 0.5,
    use_morph_erode: bool = False,
):
    """Shuffle the masked foreground of ``x_f`` into an anonymized patch mosaic.

    The foreground is first centered/rescaled onto the canvas, then cut into
    ``patch`` x ``patch`` tiles.  Tiles whose mean mask coverage reaches
    ``fg_thresh`` (relaxed in two steps when too few qualify) are randomly
    permuted, optionally flipped (probability ``flip_prob``), packed into a
    near-square rectangle, rescaled to fit, and pasted centered on a
    ``bg_value`` background.

    Args:
        x_f: image tensor, shape (C, H, W), values presumably in [-1, 1]
            (outputs are clamped to that range) — TODO confirm.
        m_f: single-channel foreground mask, shape (H, W).
        patch: tile side length in pixels.

    Returns:
        ``(image, mask, used_fallback)`` where ``used_fallback`` is True when
        packing was impossible and the plain centered composition is returned.

    NOTE(review): uses torch's global RNG (randperm / rand), so results are
    reproducible only under a fixed torch seed.
    """
    C, H, W = x_f.shape
    device = x_f.device
    P = int(patch)

    # Pad the image reflectively (mask with zeros) so H and W divide by P.
    x_pad, pad = _pad_to_multiple(x_f, P, mode="reflect")
    l, r, t, b = pad
    H2, W2 = x_pad.shape[-2], x_pad.shape[-1]
    m_pad = F.pad(m_f.unsqueeze(0).unsqueeze(0), (l, r, t, b), mode="constant", value=0.0).squeeze(0)

    # Center & rescale the foreground first so tiles cover it densely.
    cs_img, cs_msk = _center_scale_foreground_to_canvas(x_pad, m_pad.squeeze(0), (H2, W2), bg_value)
    if (cs_msk > 0.5).sum() == 0:
        # Empty mask: return the blank composition, flagged as fallback.
        out_img = _unpad(cs_img, pad).clamp_(-1, 1)
        out_msk = _unpad(cs_msk, pad).clamp_(0, 1)
        return out_img, out_msk, True

    # Optionally erode the mask (min-pool via inverted max-pool) so boundary
    # tiles don't leak background pixels.
    m_eff = cs_msk
    if use_morph_erode:
        erode_px = int(max(1, min(6, round(P * 0.03))))
        m_eff = 1.0 - F.max_pool2d(1.0 - cs_msk, kernel_size=2 * erode_px + 1, stride=1, padding=erode_px)

    # Second padding pass: the centered canvas itself re-padded to multiples of P.
    x_pad2, pad2 = _pad_to_multiple(cs_img, P, mode="reflect")
    m_pad2 = F.pad(m_eff, pad2, mode="constant", value=0.0)
    H3, W3 = x_pad2.shape[-2], x_pad2.shape[-1]

    # Per-tile foreground coverage: mean mask value of each P x P cell.
    m_pool = F.avg_pool2d(m_pad2, kernel_size=P, stride=P).view(-1)

    # Candidate thresholds, strictest first; relaxed if too few tiles qualify.
    base_thr = float(fg_thresh)
    thr_candidates = [base_thr, max(base_thr - 0.05, 0.75), max(base_thr - 0.10, 0.60)]

    # Unfold the image into per-tile columns: (1, C*P*P, N).
    x_unf = F.unfold(x_pad2.unsqueeze(0), kernel_size=P, stride=P)
    N = x_unf.shape[-1]

    # Pick the first threshold that yields at least min_patches tiles.
    sel = None
    for thr in thr_candidates:
        idx = (m_pool >= (thr - 1e-6)).nonzero(as_tuple=False).squeeze(1)
        if idx.numel() >= min_patches:
            sel = idx
            break
    if sel is None:
        # Not enough foreground tiles even at the loosest threshold.
        img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
        msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
        return img_fallback, msk_fallback, True

    # Defensive re-validation of indices before gathering.
    sel = sel.to(device=device, dtype=torch.long)
    sel = sel[(sel >= 0) & (sel < N)]
    if sel.numel() == 0:
        img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
        msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
        return img_fallback, msk_fallback, True

    # Shuffle tile order with the global RNG.
    perm = torch.randperm(sel.numel(), device=device, dtype=torch.long)
    sel = sel[perm]
    chosen_x = x_unf[:, :, sel]
    K = chosen_x.shape[-1]
    if K == 0:
        img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
        msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
        return img_fallback, msk_fallback, True

    # Randomly flip a subset of tiles, each either horizontally or vertically.
    if flip_prob > 0:
        cx4 = chosen_x.view(1, C, P, P, K)
        do_flip = (torch.rand(K, device=device) < flip_prob)
        coin = (torch.rand(K, device=device) < 0.5)
        flip_h = do_flip & coin
        flip_v = do_flip & (~coin)
        if flip_h.any():
            cx4[..., flip_h] = cx4[..., flip_h].flip(dims=[3])
        if flip_v.any():
            cx4[..., flip_v] = cx4[..., flip_v].flip(dims=[2])
        chosen_x = cx4.view(1, C * P * P, K)

    # Choose a near-square rows_full x cols grid that fits the canvas,
    # truncating tiles that don't fill a complete row.
    max_cols = max(1, W3 // P)
    max_rows = max(1, H3 // P)
    capacity = max_rows * max_cols
    K_cap = min(K, capacity)
    cols = int(max(1, min(int(math.floor(math.sqrt(K_cap))), max_cols)))
    rows_full = min(max_rows, K_cap // cols)
    K_used = rows_full * cols
    if K_used == 0:
        img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
        msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
        return img_fallback, msk_fallback, True

    # Fold the chosen tiles back into a dense rectangle.
    chosen_x = chosen_x[:, :, :K_used]
    rect_unf = torch.full((1, C * P * P, rows_full * cols), bg_value, device=device, dtype=x_f.dtype)
    rect_unf[:, :, :K_used] = chosen_x
    rect = F.fold(rect_unf, output_size=(rows_full * P, cols * P), kernel_size=P, stride=P).squeeze(0)

    # Matching all-ones mask for the packed rectangle.
    ones_patch = torch.ones((1, 1 * P * P, K_used), device=device, dtype=x_f.dtype)
    mask_rect_unf = torch.zeros((1, 1 * P * P, rows_full * cols), device=device, dtype=x_f.dtype)
    mask_rect_unf[:, :, :K_used] = ones_patch
    rect_mask = F.fold(mask_rect_unf, output_size=(rows_full * P, cols * P), kernel_size=P, stride=P).squeeze(0)

    # Rescale the rectangle (aspect preserved) to fit the padded canvas.
    Hr, Wr = rect.shape[-2], rect.shape[-1]
    s = min(H3 / max(1, Hr), W3 / max(1, Wr))
    Ht = min(max(1, int(math.floor(Hr * s))), H3)
    Wt = min(max(1, int(math.floor(Wr * s))), W3)

    rect_up = _resize(rect.unsqueeze(0), (Ht, Wt), is_mask=False).squeeze(0)
    rect_mask_up = _resize(rect_mask.unsqueeze(0), (Ht, Wt), is_mask=True).squeeze(0)

    # Paste the mosaic centered on a fresh background canvas.
    canvas_x = torch.full((C, H3, W3), bg_value, device=device, dtype=x_f.dtype)
    canvas_m = torch.zeros((1, H3, W3), device=device, dtype=x_f.dtype)
    top, left = (H3 - Ht) // 2, (W3 - Wt) // 2
    canvas_x[:, top:top + Ht, left:left + Wt] = rect_up
    canvas_m[:, top:top + Ht, left:left + Wt] = rect_mask_up

    # Undo both padding passes and clamp to valid ranges.
    out_img = _unpad(_unpad(canvas_x, pad2), pad).clamp_(-1, 1)
    out_msk = _unpad(_unpad(canvas_m, pad2), pad).clamp_(0, 1)
    return out_img, out_msk, False
|
|
|
|
@torch.no_grad()
def _compose_centered_foreground(x_f: torch.Tensor, m_f3: torch.Tensor, target_hw: Tuple[int, int], bg_value: float = 1.0):
    """Collapse a 3-channel mask to one binary channel, then center-place the
    masked foreground of ``x_f`` on a ``bg_value`` canvas of ``target_hw``.

    Returns the ``(image_canvas, mask_canvas)`` pair from
    ``_center_scale_foreground_to_canvas``.
    """
    # Majority vote across channels, re-binarized at 0.5.
    single = (m_f3 > 0.5).float().mean(dim=0)
    single = (single > 0.5).float()
    return _center_scale_foreground_to_canvas(x_f, single, target_hw, bg_value)
|
|
class RefacadePipeline(DiffusionPipeline, WanLoraLoaderMixin):
    """Wan-VACE based video re-texturing ("refacade") pipeline.

    Given a source video, a foreground mask, and a reference image (with an
    optional reference mask), the pipeline (1) builds a texture-removed
    "mesh" latent of the masked region via ``texture_remover``, (2) conditions
    the main ``transformer`` on that mesh plus a (optionally patch-shuffled)
    reference image, and (3) denoises with a flow-matching Euler scheduler.
    """

    # Order used by diffusers' sequential model CPU offloading.
    model_cpu_offload_seq = "texture_remover->transformer->vae"

    def __init__(
        self,
        vae,
        scheduler: FlowMatchEulerDiscreteScheduler,
        transformer: VaceMMModel = None,
        texture_remover: VaceWanModel = None,
    ):
        """Register sub-models and load the pre-computed text embeddings.

        Args:
            vae: video VAE used for encoding/decoding (WanVAE expected).
            scheduler: flow-matching Euler scheduler for the main loop.
            transformer: main VACE denoising transformer.
            texture_remover: auxiliary VACE model that strips texture from
                the masked foreground (mesh generation).
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            texture_remover=texture_remover,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Wan VAE compression factors: 4x temporal, 8x per spatial axis.
        self.vae_scale_factor_temporal = 4
        self.vae_scale_factor_spatial = 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
        # Pre-computed prompt embeddings (empty & negative prompt) loaded from
        # hard-coded relative paths, so no text encoder is needed at runtime.
        # NOTE(review): assumes the process CWD contains ./text_embedding/ —
        # confirm deployment layout.
        self.empty_embedding = torch.load(
            "./text_embedding/empty.pt",
            map_location="cpu"
        )
        self.negative_embedding = torch.load(
            "./text_embedding/negative.pt",
            map_location="cpu"
        )

    def vace_encode_masks(self, masks: torch.Tensor):
        """Downsample a pixel-space mask video to the VACE latent grid.

        Keeps the first channel of ``masks`` (B, C, D, H, W), rearranges each
        8x8 spatial patch into the channel dim (space-to-depth), then
        nearest-interpolates time down by the temporal VAE factor, producing
        (B, 64, ceil(D/4), H/8, W/8).

        NOTE(review): the ``view`` below requires H and W divisible by 16 —
        upstream sizing rounds to multiples of 16, but confirm for all callers.
        """
        masks = masks[:, :1, :, :, :]
        B, C, D, H, W = masks.shape
        patch_h, patch_w = self.vae_scale_factor_spatial, self.vae_scale_factor_spatial
        stride_t = self.vae_scale_factor_temporal
        patch_count = patch_h * patch_w
        # Ceil-divide frames by the temporal stride.
        new_D = (D + stride_t - 1) // stride_t
        new_H = 2 * (H // (patch_h * 2))
        new_W = 2 * (W // (patch_w * 2))
        masks = masks[:, 0]
        # (B, D, H, W) -> (B, D, new_H, ph, new_W, pw) -> (B, ph*pw, D, new_H, new_W)
        masks = masks.view(B, D, new_H, patch_h, new_W, patch_w)
        masks = masks.permute(0, 3, 5, 1, 2, 4)
        masks = masks.reshape(B, patch_count, D, new_H, new_W)
        # Only the temporal size changes here; spatial sizes already match.
        masks = F.interpolate(
            masks,
            size=(new_D, new_H, new_W),
            mode="nearest-exact"
        )
        return masks

    def preprocess_conditions(
        self,
        video: Optional[List[PipelineImageInput]] = None,
        mask: Optional[List[PipelineImageInput]] = None,
        reference_image: Optional[PIL.Image.Image] = None,
        reference_mask: Optional[PIL.Image.Image] = None,
        batch_size: int = 1,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        reference_patch_ratio: float = 0.2,
        fg_thresh: float = 0.9,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        """Normalize video / mask / reference inputs into model-ready tensors.

        Sizes the video to at most the ``height * width`` pixel budget (on
        16-px multiples), maps the mask video to [0, 1], and converts the
        reference image + mask into a centered (ratio == 1.0) or
        patch-shuffled composition on a white canvas.

        Returns:
            ``(video, mask, ref_image, ref_mask)`` — video/mask truncated to
            ``num_frames``; ref tensors shaped (B, 3, 1, H, W).
        """
        # Latent patching works on 16-px multiples (VAE stride 8 x patch 2).
        base = self.vae_scale_factor_spatial * 2
        video_height, video_width = self.video_processor.get_default_height_width(video[0])

        # Scale down (aspect preserved) when the native size exceeds the budget.
        if video_height * video_width > height * width:
            scale_w = width / video_width
            scale_h = height / video_height
            video_height, video_width = int(video_height * scale_h), int(video_width * scale_w)

        if video_height % base != 0 or video_width % base != 0:
            logger.warning(
                f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}. "
            )
            video_height = (video_height // base) * base
            video_width = (video_width // base) * base

        assert video_height * video_width <= height * width

        video = self.video_processor.preprocess_video(video, video_height, video_width)
        image_size = (video_height, video_width)

        # Masks arrive as videos in [-1, 1]; map back to [0, 1].
        mask = self.video_processor.preprocess_video(mask, video_height, video_width)
        mask = torch.clamp((mask + 1) / 2, min=0, max=1)

        video = video.to(dtype=dtype, device=device)
        mask = mask.to(dtype=dtype, device=device)

        if reference_image is None:
            raise ValueError("reference_image must be provided when using IMAGE_CONTROL mode.")

        # Only the first reference image / mask of a list is used.
        if isinstance(reference_image, (list, tuple)):
            ref_img_pil = reference_image[0]
        else:
            ref_img_pil = reference_image

        if reference_mask is not None and isinstance(reference_mask, (list, tuple)):
            ref_mask_pil = reference_mask[0]
        else:
            ref_mask_pil = reference_mask

        ref_img_t = self.video_processor.preprocess(ref_img_pil, image_size[0], image_size[1])
        if ref_img_t.dim() == 4 and ref_img_t.shape[0] == 1:
            ref_img_t = ref_img_t[0]
        if ref_img_t.shape[0] == 1:
            # Grayscale reference: replicate to 3 channels.
            ref_img_t = ref_img_t.repeat(3, 1, 1)
        ref_img_t = ref_img_t.to(dtype=dtype, device=device)

        # Reference mask -> binary 3-channel tensor; all-ones when absent.
        H, W = image_size
        if ref_mask_pil is not None:
            if not isinstance(ref_mask_pil, Image.Image):
                ref_mask_pil = Image.fromarray(np.array(ref_mask_pil))
            ref_mask_pil = ref_mask_pil.convert("L")
            # PIL resize takes (width, height).
            ref_mask_pil = ref_mask_pil.resize((W, H), Image.NEAREST)
            mask_arr = np.array(ref_mask_pil)
            m = torch.from_numpy(mask_arr).float() / 255.0
            m = (m > 0.5).float()
            ref_msk3 = m.unsqueeze(0).repeat(3, 1, 1)
        else:
            ref_msk3 = torch.ones(3, H, W, dtype=dtype)

        ref_msk3 = ref_msk3.to(dtype=dtype, device=device)

        # ratio == 1.0 -> plain centered composition; otherwise anonymize the
        # reference via random patch shuffling.
        if math.isclose(reference_patch_ratio, 1.0, rel_tol=1e-6, abs_tol=1e-6):
            cs_img, cs_m = _compose_centered_foreground(
                x_f=ref_img_t,
                m_f3=ref_msk3,
                target_hw=image_size,
                bg_value=1.0,
            )
            ref_img_out = cs_img
            ref_mask_out = cs_m
        else:
            patch = _sample_patch_size_from_hw(
                H=image_size[0],
                W=image_size[1],
                ratio=reference_patch_ratio,
            )

            # Collapse the 3-channel mask to one binary channel.
            m_bin = (ref_msk3 > 0.5).float().mean(dim=0)
            m_bin = (m_bin > 0.5).float()
            reshuffled, reshuf_mask, used_fb = _masked_patch_pack_to_center_rectangle(
                x_f=ref_img_t,
                m_f=m_bin,
                patch=patch,
                fg_thresh=fg_thresh,
                bg_value=1.0,
                min_patches=4,
            )

            ref_img_out = reshuffled
            ref_mask_out = reshuf_mask

        # Batch size override (defaults to the video's batch dim).
        B = video.shape[0]
        if batch_size is not None:
            B = batch_size

        # Add batch and a single-frame time dim: (B, C, 1, H, W).
        ref_image = ref_img_out.unsqueeze(0).unsqueeze(2).expand(B, -1, -1, -1, -1).contiguous()
        ref_mask = ref_mask_out.unsqueeze(0).unsqueeze(2).expand(B, 3, -1, -1, -1).contiguous()

        ref_image = ref_image.to(dtype=dtype, device=device)
        ref_mask = ref_mask.to(dtype=dtype, device=device)

        return video[:, :, :num_frames], mask[:, :, :num_frames], ref_image, ref_mask

    @torch.no_grad()
    def texture_remove(self, foreground_latent):
        """Generate a texture-free "mesh" latent from the foreground latent.

        Runs the ``texture_remover`` VACE model for 3 flow-matching Euler
        steps starting from pure noise, with an all-zero text context and the
        foreground latent as VACE conditioning.  Uses a fresh local scheduler
        so the pipeline's main scheduler state is untouched.

        NOTE(review): ``torch.randn_like`` uses the global RNG (no generator),
        so mesh latents are not seeded by the pipeline's ``generator``.
        """
        sample_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1)
        # Null text conditioning: zero embedding of shape (256, 4096).
        text_embedding = torch.zeros(
            [256, 4096],
            device=foreground_latent.device,
            dtype=foreground_latent.dtype
        )
        context = text_embedding.unsqueeze(0).expand(
            foreground_latent.shape[0], -1, -1
        ).to(foreground_latent.device)
        sample_scheduler.set_timesteps(3, device=foreground_latent.device)
        timesteps = sample_scheduler.timesteps
        noise = torch.randn_like(
            foreground_latent,
            dtype=foreground_latent.dtype,
            device=foreground_latent.device
        )
        # Token count after 2x2 spatial patching (hence / 4).
        seq_len = math.ceil(
            noise.shape[2] * noise.shape[3] * noise.shape[4] / 4
        )
        latents = noise
        arg_c = {"context": context, "seq_len": seq_len}
        # NOTE(review): hard-codes CUDA autocast; CPU-only runs would need this changed.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            for _, t in enumerate(timesteps):
                timestep = torch.stack([t]).to(foreground_latent.device)
                noise_pred_cond = self.texture_remover(
                    latents,
                    t=timestep,
                    vace_context=foreground_latent,
                    vace_context_scale=1,
                    **arg_c
                )[0]
                temp_x0 = sample_scheduler.step(
                    noise_pred_cond, t, latents, return_dict=False
                )[0]
                latents = temp_x0
        return latents

    def dilate_mask_hw(self, mask: torch.Tensor, radius: int = 3) -> torch.Tensor:
        """Binary-dilate a (B, C, F, H, W) mask spatially by ``radius`` px.

        Implemented as a per-frame depthwise conv2d with an all-ones
        (2*radius+1)^2 kernel followed by a > 0 threshold.
        """
        B, C, F_, H, W = mask.shape
        k = 2 * radius + 1
        # Fold frames into the batch dim so conv2d works per-frame.
        mask_2d = mask.permute(0, 2, 1, 3, 4).reshape(B * F_, C, H, W)
        kernel = torch.ones(
            (C, 1, k, k),
            device=mask.device,
            dtype=mask.dtype
        )
        dilated_2d = F.conv2d(
            mask_2d,
            weight=kernel,
            bias=None,
            stride=1,
            padding=radius,
            groups=C
        )
        # Any covered pixel becomes foreground.
        dilated_2d = (dilated_2d > 0).to(mask.dtype)
        dilated = dilated_2d.view(B, F_, C, H, W).permute(0, 2, 1, 3, 4)
        return dilated

    def prepare_vace_latents(
        self,
        dilate_radius: int,
        video: torch.Tensor,
        mask: torch.Tensor,
        reference_image: Optional[torch.Tensor] = None,
        reference_mask: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode all VACE conditioning streams into latent space.

        Splits the video into background ("inactive": outside the dilated
        mask) and foreground ("reactive": inside the original mask), encodes
        both with the VAE, converts the foreground latent into a texture-free
        mesh latent, and encodes the reference image plus an all-white
        negative reference.  Masks are downsampled to the latent grid.

        Returns:
            ``(inactive_latent, mesh_latent, ref_latent, neg_ref_latent,
            mask, ref_mask)``.
        """
        device = device or self._execution_device

        vae_dtype = self.vae.dtype
        video = video.to(dtype=vae_dtype)
        mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
        # Keep the original mask for the foreground; dilate only the cut-out
        # used on the background so a safety margin is removed around edges.
        mask_clone = mask.clone()
        mask = self.dilate_mask_hw(mask, dilate_radius)
        inactive = video * (1 - mask)
        reactive = video * mask_clone
        reactive_latent = self.vae.encode(reactive)
        mesh_latent = self.texture_remove(reactive_latent)

        inactive_latent = self.vae.encode(inactive)
        ref_latent = self.vae.encode(reference_image)
        # An all-ones (white) image serves as the negative reference for CFG.
        neg_ref_latent = self.vae.encode(torch.ones_like(reference_image))

        reference_mask = torch.where(reference_mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
        mask = self.vace_encode_masks(mask)
        ref_mask = self.vace_encode_masks(reference_mask)

        return inactive_latent, mesh_latent, ref_latent, neg_ref_latent, mask, ref_mask

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Sample (or pass through) initial noise latents of shape
        (B, C, (num_frames - 1) // 4 + 1, height / 8, width / 8)."""
        if latents is not None:
            # Caller-supplied latents: just move to the right device/dtype.
            return latents.to(device=device, dtype=dtype)

        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_channels_latents,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    @property
    def guidance_scale(self):
        # Set by __call__ at the start of each run.
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        # CFG is active whenever guidance_scale exceeds 1.
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        # Number of scheduler timesteps of the current/last run.
        return self._num_timesteps

    @property
    def current_timestep(self):
        # Timestep being denoised right now (None outside the loop).
        return self._current_timestep

    @torch.no_grad()
    def __call__(
        self,
        video: Optional[PipelineImageInput] = None,
        mask: Optional[PipelineImageInput] = None,
        reference_image: Optional[PipelineImageInput] = None,
        reference_mask: Optional[PipelineImageInput] = None,
        conditioning_scale: float = 1.0,
        dilate_radius: int = 3,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        num_inference_steps: int = 20,
        guidance_scale: float = 1.5,
        num_videos_per_prompt: Optional[int] = 1,
        reference_patch_ratio: float = 0.2,
        fg_thresh: float = 0.9,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
    ):
        """Run the full refacade sampling loop.

        Preprocesses inputs, encodes VACE conditioning (background + mesh +
        mask, reference contexts), then denoises with the reference latent
        prepended as latent frame 0; only the video frames are stepped.

        Returns:
            ``RefacadePipelineOutput(frames, meshes, ref_img)``, or the same
            triple as a plain tuple when ``return_dict`` is False.  With
            ``output_type == "latent"`` the raw latents are returned instead.
        """
        # The temporal VAE needs num_frames % 4 == 1; round down otherwise.
        if num_frames % self.vae_scale_factor_temporal != 1:
            logger.warning(
                f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
            )
            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
        num_frames = max(num_frames, 1)

        self._guidance_scale = guidance_scale

        device = self._execution_device
        batch_size = 1

        vae_dtype = self.vae.dtype
        # NOTE(review): falls back to self.transformer_2, which __init__ never
        # registers — confirm that attribute exists when transformer is None.
        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        video, mask, reference_image, reference_mask = self.preprocess_conditions(
            video,
            mask,
            reference_image,
            reference_mask,
            batch_size,
            height,
            width,
            num_frames,
            reference_patch_ratio,
            fg_thresh,
            torch.float16,
            device,
        )

        # VACE context: [background latent | mesh latent | latent mask];
        # c1 / c1_negative are the positive / negative reference contexts.
        inactive_latent, mesh_latent, ref_latent, neg_ref_latent, mask, ref_mask = self.prepare_vace_latents(dilate_radius, video, mask, reference_image, reference_mask, device)
        c = torch.cat([inactive_latent, mesh_latent, mask], dim=1)
        c1 = torch.cat([ref_latent, ref_mask], dim=1)
        c1_negative = torch.cat(
            [neg_ref_latent, torch.zeros_like(ref_mask)],
            dim=1
        )

        num_channels_latents = 16
        noise = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float16,
            device,
            generator,
            latents,
        )

        # Prepend the reference latent as frame 0 of each latent stream.
        latents_cond = torch.cat([ref_latent, noise], dim=2)
        latents_uncond = torch.cat([neg_ref_latent, noise], dim=2)

        # Token counts after 2x2 spatial patching (hence / 4).
        seq_len = math.ceil(
            latents_cond.shape[2] *
            latents_cond.shape[3] *
            latents_cond.shape[4] / 4
        )
        seq_len_ref = math.ceil(
            ref_latent.shape[2] *
            ref_latent.shape[3] *
            ref_latent.shape[4] / 4
        )
        context = self.empty_embedding.unsqueeze(0).expand(batch_size, -1, -1).to(device)
        context_neg = self.negative_embedding.unsqueeze(0).expand(batch_size, -1, -1).to(device)
        arg_c = {
            "context": context,
            "seq_len": seq_len,
            "seq_len_ref": seq_len_ref
        }
        arg_c_null = {
            "context": context_neg,
            "seq_len": seq_len,
            "seq_len_ref": seq_len_ref
        }

        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                self._current_timestep = t
                timestep = t.expand(batch_size)

                # NOTE(review): hard-codes CUDA autocast (fp16).
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    noise_pred = self.transformer(
                        latents_cond,
                        t=timestep,
                        vace_context=c,
                        ref_context=c1,
                        vace_context_scale=conditioning_scale,
                        **arg_c,
                    )[0]

                    # CFG: unconditional pass drops the VACE context (scale 0)
                    # and uses the white-image negative reference.
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond = self.transformer(
                            latents_uncond,
                            t=timestep,
                            vace_context=c,
                            ref_context=c1_negative,
                            vace_context_scale=0,
                            **arg_c_null,
                        )[0]
                        # NOTE(review): unsqueeze(0) restores the batch dim
                        # only on the CFG path — confirm the transformer's
                        # output shape makes the no-CFG path consistent.
                        noise_pred = (noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)).unsqueeze(0)

                # Step only the video frames; latent frame 0 is the fixed
                # reference and is re-prepended after each step.
                temp_x0 = self.scheduler.step(noise_pred[:, :, 1:],
                                              t,
                                              latents_cond[:, :, 1:],
                                              return_dict=False)[0]
                latents_cond = torch.cat([ref_latent, temp_x0], dim=2)
                latents_uncond = torch.cat([neg_ref_latent, temp_x0], dim=2)
                progress_bar.update()

        self._current_timestep = None

        # NOTE(review): temp_x0 is only bound if the loop ran at least once
        # (num_inference_steps >= 1).
        if not output_type == "latent":
            latents = temp_x0
            latents = latents.to(vae_dtype)
            video = self.vae.decode(latents)
            video = self.video_processor.postprocess_video(video, output_type=output_type)
            mesh = self.vae.decode(mesh_latent.to(vae_dtype))
            mesh = self.video_processor.postprocess_video(mesh, output_type=output_type)
            # Reference image back to an HWC uint8 array in [0, 255].
            ref_img = reference_image.cpu().squeeze(0).squeeze(1).permute(1, 2, 0).numpy()
            ref_img = ((ref_img+1)*255/2).astype(np.uint8)
        else:
            video = temp_x0
            mesh = mesh_latent
            ref_img = ref_latent

        # Release any CPU-offload hooks installed on the sub-models.
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video, mesh, ref_img)

        return RefacadePipelineOutput(frames=video, meshes=mesh, ref_img=ref_img)
|
|