| import os, gc, math |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import hashlib |
|
|
| from .wanvideo.schedulers import get_scheduler, scheduler_list |
|
|
| from .utils import (log, clip_encode_image_tiled, add_noise_to_reference_video, set_module_tensor_to_device) |
| from .taehv import TAEHV |
|
|
| from comfy import model_management as mm |
| from comfy.utils import ProgressBar, common_upscale |
| from comfy.clip_vision import clip_preprocess, ClipVisionModel |
| import folder_paths |
|
|
| script_directory = os.path.dirname(os.path.abspath(__file__)) |
|
|
| device = mm.get_torch_device() |
| offload_device = mm.unet_offload_device() |
|
|
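| # Wan's causal video VAE downsamples by (t, h, w) = (4, 8, 8); the transformer patchifies latents with a (1, 2, 2) patch size. |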
| VAE_STRIDE = (4, 8, 8) |
| PATCH_SIZE = (1, 2, 2) |
|
|
|
|
| class WanVideoEnhanceAVideo: |
| @classmethod |
| def INPUT_TYPES(s): |
| return { |
| "required": { |
| "weight": ("FLOAT", {"default": 2.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "The FETA weight for Enhance-A-Video"}), |
| "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"}), |
| "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"}), |
| }, |
| } |
| RETURN_TYPES = ("FETAARGS",) |
| RETURN_NAMES = ("feta_args",) |
| FUNCTION = "setargs" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video" |
|
|
| def setargs(self, **kwargs): |
| return (kwargs, ) |
|
|
| class WanVideoSetBlockSwap: |
| @classmethod |
| def INPUT_TYPES(s): |
| return { |
| "required": { |
| "model": ("WANVIDEOMODEL", ), |
| }, |
| "optional": { |
| "block_swap_args": ("BLOCKSWAPARGS", ), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDEOMODEL",) |
| RETURN_NAMES = ("model", ) |
| FUNCTION = "loadmodel" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def loadmodel(self, model, block_swap_args=None): |
| if block_swap_args is None: |
| return (model,) |
| patcher = model.clone() |
| if 'transformer_options' not in patcher.model_options: |
| patcher.model_options['transformer_options'] = {} |
| patcher.model_options["transformer_options"]["block_swap_args"] = block_swap_args |
|
|
| return (patcher,) |
|
|
| class WanVideoSetRadialAttention: |
| @classmethod |
| def INPUT_TYPES(s): |
| return { |
| "required": { |
| "model": ("WANVIDEOMODEL", ), |
| "dense_attention_mode": ([ |
| "sdpa", |
| "flash_attn_2", |
| "flash_attn_3", |
| "sageattn", |
| "sparse_sage_attention", |
| ], {"default": "sageattn", "tooltip": "The attention mode for dense attention"}), |
| "dense_blocks": ("INT", {"default": 1, "min": 0, "max": 40, "step": 1, "tooltip": "Number of blocks to apply normal attention to"}), |
| "dense_vace_blocks": ("INT", {"default": 1, "min": 0, "max": 15, "step": 1, "tooltip": "Number of VACE blocks to apply normal attention to"}), |
| "dense_timesteps": ("INT", {"default": 2, "min": 0, "max": 100, "step": 1, "tooltip": "Number of initial steps that use dense attention before sparse attention is applied"}), |
| "decay_factor": ("FLOAT", {"default": 0.2, "min": 0, "max": 1, "step": 0.01, "tooltip": "Controls how quickly the attention window shrinks as the distance between frames increases in the sparse attention mask."}), |
| "block_size": ([128, 64], {"default": 128, "tooltip": "Radial attention block size; larger blocks are faster but restrict the usable resolutions more."}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDEOMODEL",) |
| RETURN_NAMES = ("model", ) |
| FUNCTION = "loadmodel" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "Sets radial attention parameters; 'dense attention' here refers to normal (non-sparse) attention" |
|
|
| def loadmodel(self, model, dense_attention_mode, dense_blocks, dense_vace_blocks, dense_timesteps, decay_factor, block_size): |
| if "radial" not in model.model.diffusion_model.attention_mode: |
| raise Exception("Enable radial attention first in the model loader.") |
| |
| patcher = model.clone() |
| if 'transformer_options' not in patcher.model_options: |
| patcher.model_options['transformer_options'] = {} |
|
|
| patcher.model_options["transformer_options"]["dense_attention_mode"] = dense_attention_mode |
| patcher.model_options["transformer_options"]["dense_blocks"] = dense_blocks |
| patcher.model_options["transformer_options"]["dense_vace_blocks"] = dense_vace_blocks |
| patcher.model_options["transformer_options"]["dense_timesteps"] = dense_timesteps |
| patcher.model_options["transformer_options"]["decay_factor"] = decay_factor |
| patcher.model_options["transformer_options"]["block_size"] = block_size |
|
|
| return (patcher,) |
|
|
| class WanVideoBlockList: |
| @classmethod |
| def INPUT_TYPES(s): |
| return { |
| "required": { |
| "blocks": ("STRING", {"default": "1", "multiline":True}), |
| } |
| } |
|
|
| RETURN_TYPES = ("INT",) |
| RETURN_NAMES = ("block_list", ) |
| FUNCTION = "create_list" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "Comma-separated list of block indices to apply block swap to. Ranges such as '0-5' or '0,2,3-5' are also accepted. Can be connected to the dense_blocks input of the 'WanVideoSetRadialAttention' node." |
|
|
| def create_list(self, blocks): |
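| # Parses e.g. "0, 2, 4-6" into [0, 2, 4, 5, 6]; duplicates are kept and order follows the input. |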
| block_list = [] |
| for line in blocks.splitlines(): |
| for part in line.split(","): |
| part = part.strip() |
| if not part: |
| continue |
| if "-" in part: |
| try: |
| start, end = map(int, part.split("-", 1)) |
| block_list.extend(range(start, end + 1)) |
| except Exception: |
| raise ValueError(f"Invalid range: '{part}'") |
| else: |
| try: |
| block_list.append(int(part)) |
| except Exception: |
| raise ValueError(f"Invalid integer: '{part}'") |
| return (block_list,) |
|
|
|
|
|
|
| |
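| # In-memory cache of LLM-extended prompts, keyed by (original prompt, extender settings), so re-runs skip the Qwen pass. |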
| _extender_cache = {} |
|
|
| cache_dir = os.path.join(script_directory, 'text_embed_cache') |
|
|
| def get_cache_path(prompt): |
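| # The cache file name is the SHA-256 hex digest of the stripped prompt text, stored as a .pt file. |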
| cache_key = prompt.strip() |
| cache_hash = hashlib.sha256(cache_key.encode('utf-8')).hexdigest() |
| return os.path.join(cache_dir, f"{cache_hash}.pt") |
|
|
| def get_cached_text_embeds(positive_prompt, negative_prompt): |
| |
| os.makedirs(cache_dir, exist_ok=True) |
|
|
| context = None |
| context_null = None |
|
|
| pos_cache_path = get_cache_path(positive_prompt) |
| neg_cache_path = get_cache_path(negative_prompt) |
|
|
| |
| if os.path.exists(pos_cache_path): |
| try: |
| log.info(f"Loading prompt embeds from cache: {pos_cache_path}") |
| context = torch.load(pos_cache_path) |
| except Exception as e: |
| log.warning(f"Failed to load cache: {e}, will re-encode.") |
|
|
| |
| if os.path.exists(neg_cache_path): |
| try: |
| log.info(f"Loading prompt embeds from cache: {neg_cache_path}") |
| context_null = torch.load(neg_cache_path) |
| except Exception as e: |
| log.warning(f"Failed to load cache: {e}, will re-encode.") |
|
|
| return context, context_null |
|
|
| class WanVideoTextEncodeCached: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "model_name": (folder_paths.get_filename_list("text_encoders"), {"tooltip": "These models are loaded from 'ComfyUI/models/text_encoders'"}), |
| "precision": (["fp32", "bf16"], |
| {"default": "bf16"} |
| ), |
| "positive_prompt": ("STRING", {"default": "", "multiline": True} ), |
| "negative_prompt": ("STRING", {"default": "", "multiline": True} ), |
| "quantization": (['disabled', 'fp8_e4m3fn'], {"default": 'disabled', "tooltip": "optional quantization method"}), |
| "use_disk_cache": ("BOOLEAN", {"default": True, "tooltip": "Cache the text embeddings to disk for faster re-use, under the custom_nodes/ComfyUI-WanVideoWrapper/text_embed_cache directory"}), |
| "device": (["gpu", "cpu"], {"default": "gpu", "tooltip": "Device to run the text encoding on."}), |
| }, |
| "optional": { |
| "extender_args": ("WANVIDEOPROMPTEXTENDER_ARGS", {"tooltip": "Use this node to extend the prompt with additional text."}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", "WANVIDEOTEXTEMBEDS", "STRING") |
| RETURN_NAMES = ("text_embeds", "negative_text_embeds", "positive_prompt") |
| OUTPUT_TOOLTIPS = ("The text embeddings for both prompts", "The text embeddings for the negative prompt only (for NAG)", "Positive prompt to display prompt extender results") |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = """Encodes text prompts into text embeddings. This node loads the T5 and completely unloads it once done, |
| leaving no VRAM or RAM footprint. If the prompts have already been cached, the T5 is not loaded at all. |
| The negative output is meant to be used with NAG; it contains only the negative prompt embeddings. |
| |
| Additionally you can provide a Qwen LLM model to extend the positive prompt with either one |
| of the original Wan templates or a custom system prompt. |
| """ |
|
|
|
|
| def process(self, model_name, precision, positive_prompt, negative_prompt, quantization='disabled', use_disk_cache=True, device="gpu", extender_args=None): |
| from .nodes_model_loading import LoadWanVideoT5TextEncoder |
| pbar = ProgressBar(3) |
|
|
| echoshot = "[1]" in positive_prompt |
|
|
| |
| orig_prompt = positive_prompt |
| if extender_args is not None: |
| extender_key = (orig_prompt, str(extender_args)) |
| if extender_key in _extender_cache: |
| positive_prompt = _extender_cache[extender_key] |
| log.info(f"Loaded extended prompt from in-memory cache: {positive_prompt}") |
| else: |
| from .qwen.qwen import QwenLoader, WanVideoPromptExtender |
| log.info("Using WanVideoPromptExtender to process prompts") |
| qwen, = QwenLoader().load( |
| extender_args["model"], |
| load_device="main_device" if device == "gpu" else "cpu", |
| precision=precision) |
| positive_prompt, = WanVideoPromptExtender().generate( |
| qwen=qwen, |
| max_new_tokens=extender_args["max_new_tokens"], |
| prompt=orig_prompt, |
| device=device, |
| force_offload=False, |
| custom_system_prompt=extender_args["system_prompt"], |
| seed=extender_args["seed"] |
| ) |
| log.info(f"Extended positive prompt: {positive_prompt}") |
| _extender_cache[extender_key] = positive_prompt |
| del qwen |
| pbar.update(1) |
|
|
| |
| if use_disk_cache: |
| context, context_null = get_cached_text_embeds(positive_prompt, negative_prompt) |
| if context is not None and context_null is not None: |
| return{ |
| "prompt_embeds": context, |
| "negative_prompt_embeds": context_null, |
| "echoshot": echoshot, |
| },{"prompt_embeds": context_null}, positive_prompt |
|
|
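| # Load the T5 text encoder; if disk caching is enabled and both prompts were cached, this point is never reached. |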
| t5, = LoadWanVideoT5TextEncoder().loadmodel(model_name, precision, "main_device", quantization) |
| pbar.update(1) |
|
|
| prompt_embeds_dict, = WanVideoTextEncode().process( |
| positive_prompt=positive_prompt, |
| negative_prompt=negative_prompt, |
| t5=t5, |
| force_offload=False, |
| model_to_offload=None, |
| use_disk_cache=use_disk_cache, |
| device=device |
| ) |
| pbar.update(1) |
| del t5 |
| mm.soft_empty_cache() |
| gc.collect() |
| return (prompt_embeds_dict, {"prompt_embeds": prompt_embeds_dict["negative_prompt_embeds"]}, positive_prompt) |
|
|
| |
| class WanVideoTextEncode: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "positive_prompt": ("STRING", {"default": "", "multiline": True} ), |
| "negative_prompt": ("STRING", {"default": "", "multiline": True} ), |
| }, |
| "optional": { |
| "t5": ("WANTEXTENCODER",), |
| "force_offload": ("BOOLEAN", {"default": True}), |
| "model_to_offload": ("WANVIDEOMODEL", {"tooltip": "Model to move to offload_device before encoding"}), |
| "use_disk_cache": ("BOOLEAN", {"default": False, "tooltip": "Cache the text embeddings to disk for faster re-use, under the custom_nodes/ComfyUI-WanVideoWrapper/text_embed_cache directory"}), |
| "device": (["gpu", "cpu"], {"default": "gpu", "tooltip": "Device to run the text encoding on."}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) |
| RETURN_NAMES = ("text_embeds",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "Encodes text prompts into text embeddings. For rudimentary prompt travel you can input multiple prompts separated by '|'; they are spread evenly over the video length" |
|
|
|
|
| def process(self, positive_prompt, negative_prompt, t5=None, force_offload=True, model_to_offload=None, use_disk_cache=False, device="gpu"): |
| if t5 is None and not use_disk_cache: |
| raise ValueError("T5 encoder is required for text encoding. Please provide a valid T5 encoder or enable disk cache.") |
|
|
| echoshot = "[1]" in positive_prompt |
|
|
| if use_disk_cache: |
| context, context_null = get_cached_text_embeds(positive_prompt, negative_prompt) |
| if context is not None and context_null is not None: |
| return{ |
| "prompt_embeds": context, |
| "negative_prompt_embeds": context_null, |
| "echoshot": echoshot, |
| }, |
| |
| if t5 is None: |
| raise ValueError("No cached text embeds found for prompts, please provide a T5 encoder.") |
|
|
| if model_to_offload is not None and device == "gpu": |
| try: |
| log.info(f"Moving video model to {offload_device}") |
| model_to_offload.model.to(offload_device) |
| except Exception as e: |
| log.warning(f"Failed to move model to offload device: {e}") |
|
|
| encoder = t5["model"] |
| dtype = t5["dtype"] |
| |
| positive_prompts = [] |
| all_weights = [] |
|
|
| |
| if "|" in positive_prompt: |
| log.info("Multiple positive prompts detected, splitting by '|'") |
| positive_prompts_raw = [p.strip() for p in positive_prompt.split('|')] |
| elif "[1]" in positive_prompt: |
| log.info("Multiple positive prompts detected, splitting by [#] and enabling EchoShot") |
| import re |
| segments = re.split(r'\[\d+\]', positive_prompt) |
| positive_prompts_raw = [segment.strip() for segment in segments if segment.strip()] |
| assert 1 < len(positive_prompts_raw) < 7, 'Input shot count must be between 2 and 6!' |
| else: |
| positive_prompts_raw = [positive_prompt.strip()] |
| |
| for p in positive_prompts_raw: |
| cleaned_prompt, weights = self.parse_prompt_weights(p) |
| positive_prompts.append(cleaned_prompt) |
| all_weights.append(weights) |
|
|
| mm.soft_empty_cache() |
|
|
| if device == "gpu": |
| device_to = mm.get_torch_device() |
| else: |
| device_to = torch.device("cpu") |
|
|
| if encoder.quantization == "fp8_e4m3fn": |
| cast_dtype = torch.float8_e4m3fn |
| else: |
| cast_dtype = encoder.dtype |
|
|
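| # Norm and embedding weights are kept in the configured dtype; the remaining weights may be cast to fp8 when quantization is enabled. |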
| params_to_keep = {'norm', 'pos_embedding', 'token_embedding'} |
| for name, param in encoder.model.named_parameters(): |
| dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype |
| value = encoder.state_dict[name] if hasattr(encoder, 'state_dict') else encoder.model.state_dict()[name] |
| set_module_tensor_to_device(encoder.model, name, device=device_to, dtype=dtype_to_use, value=value) |
| if hasattr(encoder, 'state_dict'): |
| del encoder.state_dict |
| mm.soft_empty_cache() |
| gc.collect() |
|
|
| with torch.autocast(device_type=mm.get_autocast_device(device_to), dtype=encoder.dtype, enabled=encoder.quantization != 'disabled'): |
| |
| if use_disk_cache and context is not None: |
| pass |
| else: |
| context = encoder(positive_prompts, device_to) |
| |
| for i, weights in enumerate(all_weights): |
| for text, weight in weights.items(): |
| log.info(f"Applying weight {weight} to prompt: {text}") |
| if len(weights) > 0: |
| context[i] = context[i] * weight |
|
|
| |
| if use_disk_cache and context_null is not None: |
| pass |
| else: |
| context_null = encoder([negative_prompt], device_to) |
|
|
| if force_offload: |
| encoder.model.to(offload_device) |
| mm.soft_empty_cache() |
| gc.collect() |
|
|
| prompt_embeds_dict = { |
| "prompt_embeds": context, |
| "negative_prompt_embeds": context_null, |
| "echoshot": echoshot, |
| } |
|
|
| |
| if use_disk_cache: |
| pos_cache_path = get_cache_path(positive_prompt) |
| neg_cache_path = get_cache_path(negative_prompt) |
| try: |
| if not os.path.exists(pos_cache_path): |
| torch.save(context, pos_cache_path) |
| log.info(f"Saved prompt embeds to cache: {pos_cache_path}") |
| except Exception as e: |
| log.warning(f"Failed to save cache: {e}") |
| try: |
| if not os.path.exists(neg_cache_path): |
| torch.save(context_null, neg_cache_path) |
| log.info(f"Saved prompt embeds to cache: {neg_cache_path}") |
| except Exception as e: |
| log.warning(f"Failed to save cache: {e}") |
|
|
| return (prompt_embeds_dict,) |
| |
| def parse_prompt_weights(self, prompt): |
| """Extract text and weights from prompts with (text:weight) format""" |
| import re |
| |
| |
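| # Matches "(text:weight)" spans, e.g. "a (red:1.2) car" -> cleaned prompt "a red car" with weights {"red": 1.2}. |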
| pattern = r'\((.*?):([\d\.]+)\)' |
| matches = re.findall(pattern, prompt) |
| |
| |
| cleaned_prompt = prompt |
| weights = {} |
| |
| for match in matches: |
| text, weight = match |
| orig_text = f"({text}:{weight})" |
| cleaned_prompt = cleaned_prompt.replace(orig_text, text) |
| weights[text] = float(weight) |
| |
| return cleaned_prompt, weights |
| |
| class WanVideoTextEncodeSingle: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "prompt": ("STRING", {"default": "", "multiline": True} ), |
| }, |
| "optional": { |
| "t5": ("WANTEXTENCODER",), |
| "force_offload": ("BOOLEAN", {"default": True}), |
| "model_to_offload": ("WANVIDEOMODEL", {"tooltip": "Model to move to offload_device before encoding"}), |
| "use_disk_cache": ("BOOLEAN", {"default": False, "tooltip": "Cache the text embeddings to disk for faster re-use, under the custom_nodes/ComfyUI-WanVideoWrapper/text_embed_cache directory"}), |
| "device": (["gpu", "cpu"], {"default": "gpu", "tooltip": "Device to run the text encoding on."}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) |
| RETURN_NAMES = ("text_embeds",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "Encodes text prompt into text embedding." |
|
|
| def process(self, prompt, t5=None, force_offload=True, model_to_offload=None, use_disk_cache=False, device="gpu"): |
| |
| encoded = None |
| echoshot = "[1]" in prompt |
| if use_disk_cache: |
| # Reuse the module-level cache directory and get_cache_path helper. |
| os.makedirs(cache_dir, exist_ok=True) |
| cache_path = get_cache_path(prompt) |
| if os.path.exists(cache_path): |
| try: |
| log.info(f"Loading prompt embeds from cache: {cache_path}") |
| encoded = torch.load(cache_path) |
| except Exception as e: |
| log.warning(f"Failed to load cache: {e}, will re-encode.") |
|
|
| if t5 is None and encoded is None: |
| raise ValueError("No cached text embeds found for prompts, please provide a T5 encoder.") |
|
|
| if encoded is None: |
| try: |
| if model_to_offload is not None and device == "gpu": |
| log.info(f"Moving video model to {offload_device}") |
| model_to_offload.model.to(offload_device) |
| mm.soft_empty_cache() |
| except Exception as e: |
| log.warning(f"Failed to move model to offload device: {e}") |
|
|
| encoder = t5["model"] |
| dtype = t5["dtype"] |
|
|
| if device == "gpu": |
| device_to = mm.get_torch_device() |
| else: |
| device_to = torch.device("cpu") |
|
|
| if encoder.quantization == "fp8_e4m3fn": |
| cast_dtype = torch.float8_e4m3fn |
| else: |
| cast_dtype = encoder.dtype |
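| # Keep norm and embedding weights in the original dtype; other weights may be cast to fp8 when quantization is enabled. |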
| params_to_keep = {'norm', 'pos_embedding', 'token_embedding'} |
| for name, param in encoder.model.named_parameters(): |
| dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype |
| value = encoder.state_dict[name] if hasattr(encoder, 'state_dict') else encoder.model.state_dict()[name] |
| set_module_tensor_to_device(encoder.model, name, device=device_to, dtype=dtype_to_use, value=value) |
| if hasattr(encoder, 'state_dict'): |
| del encoder.state_dict |
| mm.soft_empty_cache() |
| gc.collect() |
| with torch.autocast(device_type=mm.get_autocast_device(device_to), dtype=encoder.dtype, enabled=encoder.quantization != 'disabled'): |
| encoded = encoder([prompt], device_to) |
|
|
| if force_offload: |
| encoder.model.to(offload_device) |
| mm.soft_empty_cache() |
|
|
| |
| if use_disk_cache: |
| try: |
| if not os.path.exists(cache_path): |
| torch.save(encoded, cache_path) |
| log.info(f"Saved prompt embeds to cache: {cache_path}") |
| except Exception as e: |
| log.warning(f"Failed to save cache: {e}") |
|
|
| prompt_embeds_dict = { |
| "prompt_embeds": encoded, |
| "negative_prompt_embeds": None, |
| "echoshot": echoshot |
| } |
| return (prompt_embeds_dict,) |
| |
| class WanVideoApplyNAG: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "original_text_embeds": ("WANVIDEOTEXTEMBEDS",), |
| "nag_text_embeds": ("WANVIDEOTEXTEMBEDS",), |
| "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.1}), |
| "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.1}), |
| "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) |
| RETURN_NAMES = ("text_embeds",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "Adds NAG prompt embeds to original prompt embeds: 'https://github.com/ChenDarYen/Normalized-Attention-Guidance'" |
|
|
| def process(self, original_text_embeds, nag_text_embeds, nag_scale, nag_tau, nag_alpha): |
| prompt_embeds_dict_copy = original_text_embeds.copy() |
| prompt_embeds_dict_copy.update({ |
| "nag_prompt_embeds": nag_text_embeds["prompt_embeds"], |
| "nag_params": { |
| "nag_scale": nag_scale, |
| "nag_tau": nag_tau, |
| "nag_alpha": nag_alpha, |
| } |
| }) |
| return (prompt_embeds_dict_copy,) |
| |
| class WanVideoTextEmbedBridge: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "positive": ("CONDITIONING",), |
| }, |
| "optional": { |
| "negative": ("CONDITIONING",), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) |
| RETURN_NAMES = ("text_embeds",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "Bridge between ComfyUI native text embedding and WanVideoWrapper text embedding" |
|
|
| def process(self, positive, negative=None): |
| prompt_embeds_dict = { |
| "prompt_embeds": positive[0][0].to(device), |
| "negative_prompt_embeds": negative[0][0].to(device) if negative is not None else None, |
| } |
| return (prompt_embeds_dict,) |
| |
| |
| class WanVideoClipVisionEncode: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "clip_vision": ("CLIP_VISION",), |
| "image_1": ("IMAGE", {"tooltip": "Image to encode"}), |
| "strength_1": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), |
| "strength_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), |
| "crop": (["center", "disabled"], {"default": "center", "tooltip": "Crop image to 224x224 before encoding"}), |
| "combine_embeds": (["average", "sum", "concat", "batch"], {"default": "average", "tooltip": "Method to combine multiple clip embeds"}), |
| "force_offload": ("BOOLEAN", {"default": True}), |
| }, |
| "optional": { |
| "image_2": ("IMAGE", ), |
| "negative_image": ("IMAGE", {"tooltip": "Image to use for the unconditional (negative) clip embeds"}), |
| "tiles": ("INT", {"default": 0, "min": 0, "max": 16, "step": 2, "tooltip": "Use matteo's tiled image encoding for improved accuracy"}), |
| "ratio": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Ratio of the tile average"}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_CLIPEMBEDS",) |
| RETURN_NAMES = ("image_embeds",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def process(self, clip_vision, image_1, strength_1, strength_2, force_offload, crop, combine_embeds, image_2=None, negative_image=None, tiles=0, ratio=1.0): |
| image_mean = [0.48145466, 0.4578275, 0.40821073] |
| image_std = [0.26862954, 0.26130258, 0.27577711] |
|
|
| if image_2 is not None: |
| image = torch.cat([image_1, image_2], dim=0) |
| else: |
| image = image_1 |
|
|
| clip_vision.model.to(device) |
| |
| negative_clip_embeds = None |
|
|
| if tiles > 0: |
| log.info("Using tiled image encoding") |
| clip_embeds = clip_encode_image_tiled(clip_vision, image.to(device), tiles=tiles, ratio=ratio) |
| if negative_image is not None: |
| negative_clip_embeds = clip_encode_image_tiled(clip_vision, negative_image.to(device), tiles=tiles, ratio=ratio) |
| else: |
| if isinstance(clip_vision, ClipVisionModel): |
| clip_embeds = clip_vision.encode_image(image).penultimate_hidden_states.to(device) |
| if negative_image is not None: |
| negative_clip_embeds = clip_vision.encode_image(negative_image).penultimate_hidden_states.to(device) |
| else: |
| pixel_values = clip_preprocess(image.to(device), size=224, mean=image_mean, std=image_std, crop=(crop != "disabled")).float() |
| clip_embeds = clip_vision.visual(pixel_values) |
| if negative_image is not None: |
| pixel_values = clip_preprocess(negative_image.to(device), size=224, mean=image_mean, std=image_std, crop=(crop != "disabled")).float() |
| negative_clip_embeds = clip_vision.visual(pixel_values) |
| |
| log.info(f"Clip embeds shape: {clip_embeds.shape}, dtype: {clip_embeds.dtype}") |
|
|
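| # Apply the per-image strength multipliers: strength_1 to the first embed, strength_2 to the second; any further images keep a strength of 1.0. |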
| weighted_embeds = [] |
| weighted_embeds.append(clip_embeds[0:1] * strength_1) |
|
|
| |
| if clip_embeds.shape[0] > 1: |
| weighted_embeds.append(clip_embeds[1:2] * strength_2) |
| |
| if clip_embeds.shape[0] > 2: |
| for i in range(2, clip_embeds.shape[0]): |
| weighted_embeds.append(clip_embeds[i:i+1]) |
| |
| |
| if combine_embeds == "average": |
| clip_embeds = torch.mean(torch.stack(weighted_embeds), dim=0) |
| elif combine_embeds == "sum": |
| clip_embeds = torch.sum(torch.stack(weighted_embeds), dim=0) |
| elif combine_embeds == "concat": |
| clip_embeds = torch.cat(weighted_embeds, dim=1) |
| elif combine_embeds == "batch": |
| clip_embeds = torch.cat(weighted_embeds, dim=0) |
| else: |
| clip_embeds = weighted_embeds[0] |
| |
|
|
| log.info(f"Combined clip embeds shape: {clip_embeds.shape}") |
| |
| if force_offload: |
| clip_vision.model.to(offload_device) |
| mm.soft_empty_cache() |
|
|
| clip_embeds_dict = { |
| "clip_embeds": clip_embeds, |
| "negative_clip_embeds": negative_clip_embeds |
| } |
|
|
| return (clip_embeds_dict,) |
| |
| class WanVideoRealisDanceLatents: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "ref_latent": ("LATENT", {"tooltip": "Reference image to encode"}), |
| "pose_cond_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply the SMPL/Hamer pose conditioning"}), |
| "pose_cond_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply the SMPL/Hamer pose conditioning"}), |
| }, |
| "optional": { |
| "smpl_latent": ("LATENT", {"tooltip": "SMPL pose image to encode"}), |
| "hamer_latent": ("LATENT", {"tooltip": "Hamer hand pose image to encode"}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("ADD_COND_LATENTS",) |
| RETURN_NAMES = ("add_cond_latents",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def process(self, ref_latent, pose_cond_start_percent, pose_cond_end_percent, hamer_latent=None, smpl_latent=None): |
| if smpl_latent is None and hamer_latent is None: |
| raise Exception("At least one of smpl_latent or hamer_latent must be provided") |
| if smpl_latent is None: |
| smpl = torch.zeros_like(hamer_latent["samples"]) |
| else: |
| smpl = smpl_latent["samples"] |
| if hamer_latent is None: |
| hamer = torch.zeros_like(smpl_latent["samples"]) |
| else: |
| hamer = hamer_latent["samples"] |
|
|
| pose_latent = torch.cat((smpl, hamer), dim=1) |
| |
| add_cond_latents = { |
| "ref_latent": ref_latent["samples"], |
| "pose_latent": pose_latent, |
| "pose_cond_start_percent": pose_cond_start_percent, |
| "pose_cond_end_percent": pose_cond_end_percent, |
| } |
|
|
| return (add_cond_latents,) |
|
|
| |
| class WanVideoAddStandInLatent: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "embeds": ("WANVIDIMAGE_EMBEDS",), |
| "ip_image_latent": ("LATENT", {"tooltip": "Reference image to encode"}), |
| "freq_offset": ("INT", {"default": 1, "min": 0, "max": 100, "step": 1, "tooltip": "EXPERIMENTAL: RoPE frequency offset between the reference and rest of the sequence"}), |
| |
| |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",) |
| RETURN_NAMES = ("image_embeds",) |
| FUNCTION = "add" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def add(self, embeds, ip_image_latent, freq_offset): |
| |
| new_entry = { |
| "ip_image_latent": ip_image_latent["samples"], |
| "freq_offset": freq_offset, |
| |
| |
| } |
|
|
| |
| updated = dict(embeds) |
| updated["standin_input"] = new_entry |
| return (updated,) |
|
|
| class WanVideoAddMTVMotion: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "embeds": ("WANVIDIMAGE_EMBEDS",), |
| "mtv_crafter_motion": ("MTVCRAFTERMOTION",), |
| "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the MTV motion"}), |
| "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply the MTV motion"}), |
| "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply the MTV motion"}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",) |
| RETURN_NAMES = ("image_embeds",) |
| FUNCTION = "add" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def add(self, embeds, mtv_crafter_motion, strength, start_percent, end_percent): |
| |
| new_entry = { |
| "mtv_motion_tokens": mtv_crafter_motion["mtv_motion_tokens"], |
| "strength": strength, |
| "start_percent": start_percent, |
| "end_percent": end_percent, |
| "global_mean": mtv_crafter_motion["global_mean"], |
| "global_std": mtv_crafter_motion["global_std"] |
| } |
|
|
| |
| updated = dict(embeds) |
| updated["mtv_crafter_motion"] = new_entry |
| return (updated,) |
|
|
| |
| class WanVideoImageToVideoEncode: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}), |
| "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}), |
| "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), |
| "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for I2V where some noise can add motion and give sharper results"}), |
| "start_latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more motion"}), |
| "end_latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more motion"}), |
| "force_offload": ("BOOLEAN", {"default": True}), |
| }, |
| "optional": { |
| "vae": ("WANVAE",), |
| "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}), |
| "start_image": ("IMAGE", {"tooltip": "Image to encode"}), |
| "end_image": ("IMAGE", {"tooltip": "End frame image(s) to encode"}), |
| "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "Control signal for the Fun model"}), |
| "fun_or_fl2v_model": ("BOOLEAN", {"default": True, "tooltip": "Enable when using official FLF2V or Fun model"}), |
| "temporal_mask": ("MASK", {"tooltip": "Temporal mask selecting which frames are conditioned on the input images"}), |
| "extra_latents": ("LATENT", {"tooltip": "Extra latents to prepend to the input, used for Skyreels A2 reference images"}), |
| "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}), |
| "add_cond_latents": ("ADD_COND_LATENTS", {"advanced": True, "tooltip": "Additional cond latents WIP"}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",) |
| RETURN_NAMES = ("image_embeds",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def process(self, width, height, num_frames, force_offload, noise_aug_strength, |
| start_latent_strength, end_latent_strength, start_image=None, end_image=None, control_embeds=None, fun_or_fl2v_model=False, |
| temporal_mask=None, extra_latents=None, clip_embeds=None, tiled_vae=False, add_cond_latents=None, vae=None): |
| |
| if start_image is None and end_image is None and add_cond_latents is None: |
| return WanVideoEmptyEmbeds().process( |
| num_frames, width, height, control_embeds=control_embeds, extra_latents=extra_latents, |
| ) |
| if vae is None: |
| raise ValueError("VAE is required for image encoding.") |
| H = height |
| W = width |
| |
| lat_h = H // vae.upsampling_factor |
| lat_w = W // vae.upsampling_factor |
|
|
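| # Snap the frame count down to the nearest 4n+1 length supported by the causal video VAE. |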
| num_frames = ((num_frames - 1) // 4) * 4 + 1 |
| two_ref_images = start_image is not None and end_image is not None |
|
|
| if start_image is None and end_image is not None: |
| fun_or_fl2v_model = True |
|
|
| base_frames = num_frames + (1 if two_ref_images and not fun_or_fl2v_model else 0) |
| if temporal_mask is None: |
| mask = torch.zeros(1, base_frames, lat_h, lat_w, device=device, dtype=vae.dtype) |
| if start_image is not None: |
| mask[:, 0:start_image.shape[0]] = 1 |
| if end_image is not None: |
| mask[:, -end_image.shape[0]:] = 1 |
| else: |
| mask = common_upscale(temporal_mask.unsqueeze(1).to(device), lat_w, lat_h, "nearest", "disabled").squeeze(1) |
| if mask.shape[0] > base_frames: |
| mask = mask[:base_frames] |
| elif mask.shape[0] < base_frames: |
| mask = torch.cat([mask, torch.zeros(base_frames - mask.shape[0], lat_h, lat_w, device=device)]) |
| mask = mask.unsqueeze(0).to(device, vae.dtype) |
|
|
| |
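| # Repeat the first mask frame 4x (and the last one too when an end image is used without the Fun/FL2V model) so the mask matches the VAE's temporal stride, then fold it into latent frames. |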
| start_mask_repeated = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1) |
| if end_image is not None and not fun_or_fl2v_model: |
| end_mask_repeated = torch.repeat_interleave(mask[:, -1:], repeats=4, dim=1) |
| mask = torch.cat([start_mask_repeated, mask[:, 1:-1], end_mask_repeated], dim=1) |
| else: |
| mask = torch.cat([start_mask_repeated, mask[:, 1:]], dim=1) |
|
|
| |
| mask = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w) |
| mask = mask.movedim(1, 2)[0] |
|
|
| |
| if start_image is not None: |
| start_image = start_image[..., :3] |
| if start_image.shape[1] != H or start_image.shape[2] != W: |
| resized_start_image = common_upscale(start_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) |
| else: |
| resized_start_image = start_image.permute(3, 0, 1, 2) |
| resized_start_image = resized_start_image * 2 - 1 |
| if noise_aug_strength > 0.0: |
| resized_start_image = add_noise_to_reference_video(resized_start_image, ratio=noise_aug_strength) |
| |
| if end_image is not None: |
| end_image = end_image[..., :3] |
| if end_image.shape[1] != H or end_image.shape[2] != W: |
| resized_end_image = common_upscale(end_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) |
| else: |
| resized_end_image = end_image.permute(3, 0, 1, 2) |
| resized_end_image = resized_end_image * 2 - 1 |
| if noise_aug_strength > 0.0: |
| resized_end_image = add_noise_to_reference_video(resized_end_image, ratio=noise_aug_strength) |
| |
| |
| if temporal_mask is None: |
| if start_image is not None and end_image is None: |
| zero_frames = torch.zeros(3, num_frames-start_image.shape[0], H, W, device=device, dtype=vae.dtype) |
| concatenated = torch.cat([resized_start_image.to(device, dtype=vae.dtype), zero_frames], dim=1) |
| del resized_start_image, zero_frames |
| elif start_image is None and end_image is not None: |
| zero_frames = torch.zeros(3, num_frames-end_image.shape[0], H, W, device=device, dtype=vae.dtype) |
| concatenated = torch.cat([zero_frames, resized_end_image.to(device, dtype=vae.dtype)], dim=1) |
| del zero_frames |
| elif start_image is None and end_image is None: |
| concatenated = torch.zeros(3, num_frames, H, W, device=device, dtype=vae.dtype) |
| else: |
| if fun_or_fl2v_model: |
| zero_frames = torch.zeros(3, num_frames-(start_image.shape[0]+end_image.shape[0]), H, W, device=device, dtype=vae.dtype) |
| else: |
| zero_frames = torch.zeros(3, num_frames-1, H, W, device=device, dtype=vae.dtype) |
| concatenated = torch.cat([resized_start_image.to(device, dtype=vae.dtype), zero_frames, resized_end_image.to(device, dtype=vae.dtype)], dim=1) |
| del resized_start_image, zero_frames |
| else: |
| temporal_mask = common_upscale(temporal_mask.unsqueeze(1), W, H, "nearest", "disabled").squeeze(1) |
| concatenated = resized_start_image[:,:num_frames].to(vae.dtype) * temporal_mask[:num_frames].unsqueeze(0).to(vae.dtype) |
| del resized_start_image, temporal_mask |
|
|
| mm.soft_empty_cache() |
| gc.collect() |
|
|
| vae.to(device) |
| y = vae.encode([concatenated], device, end_=(end_image is not None and not fun_or_fl2v_model),tiled=tiled_vae)[0] |
| del concatenated |
|
|
| has_ref = False |
| if extra_latents is not None: |
| samples = extra_latents["samples"].squeeze(0) |
| y = torch.cat([samples, y], dim=1) |
| mask = torch.cat([torch.ones_like(mask[:, 0:samples.shape[1]]), mask], dim=1) |
| num_frames += samples.shape[1] * 4 |
| has_ref = True |
| y[:, :1] *= start_latent_strength |
| y[:, -1:] *= end_latent_strength |
|
|
| |
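| # Maximum sequence length seen by the transformer: latent frames x (2x2) spatial patches per latent frame. |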
| patches_per_frame = lat_h * lat_w // (PATCH_SIZE[1] * PATCH_SIZE[2]) |
| frames_per_stride = (num_frames - 1) // 4 + (2 if end_image is not None and not fun_or_fl2v_model else 1) |
| max_seq_len = frames_per_stride * patches_per_frame |
|
|
| if add_cond_latents is not None: |
| add_cond_latents["ref_latent_neg"] = vae.encode(torch.zeros(1, 3, 1, H, W, device=device, dtype=vae.dtype), device) |
| |
| if force_offload: |
| vae.model.to(offload_device) |
| mm.soft_empty_cache() |
| gc.collect() |
|
|
| image_embeds = { |
| "image_embeds": y, |
| "clip_context": clip_embeds.get("clip_embeds", None) if clip_embeds is not None else None, |
| "negative_clip_context": clip_embeds.get("negative_clip_embeds", None) if clip_embeds is not None else None, |
| "max_seq_len": max_seq_len, |
| "num_frames": num_frames, |
| "lat_h": lat_h, |
| "lat_w": lat_w, |
| "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None, |
| "end_image": resized_end_image if end_image is not None else None, |
| "fun_or_fl2v_model": fun_or_fl2v_model, |
| "has_ref": has_ref, |
| "add_cond_latents": add_cond_latents, |
| "mask": mask |
| } |
|
|
| return (image_embeds,) |
| |
| |
| class WanVideoAnimateEmbeds: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "vae": ("WANVAE",), |
| "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}), |
| "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}), |
| "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), |
| "force_offload": ("BOOLEAN", {"default": True}), |
| "frame_window_size": ("INT", {"default": 77, "min": 1, "max": 1000, "step": 1, "tooltip": "Number of frames to use for temporal attention window"}), |
| "colormatch": ( |
| [ |
| 'disabled', |
| 'mkl', |
| 'hm', |
| 'reinhard', |
| 'mvgd', |
| 'hm-mvgd-hm', |
| 'hm-mkl-hm', |
| ], { |
| "default": 'disabled', "tooltip": "Color matching method to use between the windows" |
| },), |
| "pose_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional multiplier for the pose"}), |
| "face_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional multiplier for the face"}), |
| }, |
| "optional": { |
| "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}), |
| "ref_images": ("IMAGE", {"tooltip": "Reference image(s) to encode"}), |
| "pose_images": ("IMAGE", {"tooltip": "Pose images to encode"}), |
| "face_images": ("IMAGE", {"tooltip": "Face images to encode"}), |
| "bg_images": ("IMAGE", {"tooltip": "background images"}), |
| "mask": ("MASK", {"tooltip": "mask"}), |
| "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",) |
| RETURN_NAMES = ("image_embeds",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def process(self, vae, width, height, num_frames, force_offload, frame_window_size, colormatch, pose_strength, face_strength, |
| ref_images=None, pose_images=None, face_images=None, clip_embeds=None, tiled_vae=False, bg_images=None, mask=None): |
|
|
| H = height |
| W = width |
|
|
| lat_h = H // vae.upsampling_factor |
| lat_w = W // vae.upsampling_factor |
|
|
| num_refs = ref_images.shape[0] if ref_images is not None else 0 |
| num_frames = ((num_frames - 1) // 4) * 4 + 1 |
|
|
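| # When the requested length exceeds the frame window, sampling runs in windowed (looping) mode and pose/background frames are kept as pixels for per-window encoding. |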
| looping = num_frames > frame_window_size |
|
|
| if num_frames < frame_window_size: |
| frame_window_size = num_frames |
|
|
| target_shape = (16, (num_frames - 1) // 4 + 1 + num_refs, lat_h, lat_w) |
| latent_window_size = ((frame_window_size - 1) // 4) |
|
|
| if not looping: |
| num_frames = num_frames + num_refs * 4 |
| else: |
| latent_window_size = latent_window_size + 1 |
|
|
| mm.soft_empty_cache() |
| gc.collect() |
| vae.to(device) |
| |
| pose_latents = ref_latents = ref_latent = None |
| if pose_images is not None: |
| pose_images = pose_images[..., :3] |
| if pose_images.shape[1] != H or pose_images.shape[2] != W: |
| resized_pose_images = common_upscale(pose_images.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) |
| else: |
| resized_pose_images = pose_images.permute(3, 0, 1, 2) |
| resized_pose_images = resized_pose_images * 2 - 1 |
| if not looping: |
| pose_latents = vae.encode([resized_pose_images.to(device, vae.dtype)], device,tiled=tiled_vae) |
| pose_latents = pose_latents.to(offload_device) |
| |
| if pose_latents.shape[2] < latent_window_size: |
| log.info(f"WanAnimate: Padding pose latents from {pose_latents.shape} to length {latent_window_size}") |
| pad_len = latent_window_size - pose_latents.shape[2] |
| pad = torch.zeros(pose_latents.shape[0], pose_latents.shape[1], pad_len, pose_latents.shape[3], pose_latents.shape[4], device=pose_latents.device, dtype=pose_latents.dtype) |
| pose_latents = torch.cat([pose_latents, pad], dim=2) |
| del resized_pose_images |
| else: |
| resized_pose_images = resized_pose_images.to(offload_device, dtype=vae.dtype) |
|
|
| bg_latents = None |
| if bg_images is not None: |
| if bg_images.shape[1] != H or bg_images.shape[2] != W: |
| resized_bg_images = common_upscale(bg_images.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) |
| else: |
| resized_bg_images = bg_images.permute(3, 0, 1, 2) |
| resized_bg_images = (resized_bg_images[:3] * 2 - 1) |
|
|
| if not looping: |
| if bg_images is None: |
| resized_bg_images = torch.zeros(3, num_frames - num_refs, H, W, device=device, dtype=vae.dtype) |
| bg_latents = vae.encode([resized_bg_images.to(device, vae.dtype)], device,tiled=tiled_vae)[0].to(offload_device) |
| del resized_bg_images |
| elif bg_images is not None: |
| resized_bg_images = resized_bg_images.to(offload_device, dtype=vae.dtype) |
|
|
| if ref_images is not None: |
| if ref_images.shape[1] != H or ref_images.shape[2] != W: |
| resized_ref_images = common_upscale(ref_images.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) |
| else: |
| resized_ref_images = ref_images.permute(3, 0, 1, 2) |
| resized_ref_images = resized_ref_images[:3] * 2 - 1 |
|
|
| ref_latent = vae.encode([resized_ref_images.to(device, vae.dtype)], device,tiled=tiled_vae)[0] |
| msk = torch.zeros(4, 1, lat_h, lat_w, device=device, dtype=vae.dtype) |
| msk[:, :num_refs] = 1 |
| ref_latent_masked = torch.cat([msk, ref_latent], dim=0).to(offload_device) |
|
|
| if mask is None: |
| bg_mask = torch.zeros(1, num_frames, lat_h, lat_w, device=offload_device, dtype=vae.dtype) |
| else: |
| bg_mask = 1 - mask[:num_frames] |
| if bg_mask.shape[0] < num_frames and not looping: |
| bg_mask = torch.cat([bg_mask, bg_mask[-1:].repeat(num_frames - bg_mask.shape[0], 1, 1)], dim=0) |
| bg_mask = common_upscale(bg_mask.unsqueeze(1), lat_w, lat_h, "nearest", "disabled").squeeze(1) |
| bg_mask = bg_mask.unsqueeze(-1).permute(3, 0, 1, 2).to(offload_device, vae.dtype) |
| |
| if bg_images is None and looping: |
| bg_mask[:, :num_refs] = 1 |
| bg_mask_mask_repeated = torch.repeat_interleave(bg_mask[:, 0:1], repeats=4, dim=1) |
| bg_mask = torch.cat([bg_mask_mask_repeated, bg_mask[:, 1:]], dim=1) |
| bg_mask = bg_mask.view(1, bg_mask.shape[1] // 4, 4, lat_h, lat_w) |
| bg_mask = bg_mask.movedim(1, 2)[0] |
|
|
| if not looping: |
| bg_latents_masked = torch.cat([bg_mask[:, :bg_latents.shape[1]], bg_latents], dim=0) |
| del bg_mask, bg_latents |
| ref_latent = torch.cat([ref_latent_masked, bg_latents_masked], dim=1) |
| else: |
| ref_latent = ref_latent_masked |
|
|
| if face_images is not None: |
| face_images = face_images[..., :3] |
| if face_images.shape[1] != 512 or face_images.shape[2] != 512: |
| resized_face_images = common_upscale(face_images.movedim(-1, 1), 512, 512, "lanczos", "center").movedim(0, 1) |
| else: |
| resized_face_images = face_images.permute(3, 0, 1, 2) |
| resized_face_images = (resized_face_images * 2 - 1).unsqueeze(0) |
| resized_face_images = resized_face_images.to(offload_device, dtype=vae.dtype) |
|
|
|
|
| seq_len = math.ceil((target_shape[2] * target_shape[3]) / 4 * target_shape[1]) |
| |
| if force_offload: |
| vae.model.to(offload_device) |
| mm.soft_empty_cache() |
| gc.collect() |
|
|
| image_embeds = { |
| "clip_context": clip_embeds.get("clip_embeds", None) if clip_embeds is not None else None, |
| "negative_clip_context": clip_embeds.get("negative_clip_embeds", None) if clip_embeds is not None else None, |
| "max_seq_len": seq_len, |
| "pose_latents": pose_latents, |
| "pose_images": resized_pose_images if pose_images is not None and looping else None, |
| "bg_images": resized_bg_images if bg_images is not None and looping else None, |
| "ref_masks": bg_mask if mask is not None and looping else None, |
| "is_masked": mask is not None, |
| "ref_latent": ref_latent, |
| "ref_image": resized_ref_images if ref_images is not None else None, |
| "face_pixels": resized_face_images if face_images is not None else None, |
| "num_frames": num_frames, |
| "target_shape": target_shape, |
| "frame_window_size": frame_window_size, |
| "lat_h": lat_h, |
| "lat_w": lat_w, |
| "vae": vae, |
| "colormatch": colormatch, |
| "looping": looping, |
| "pose_strength": pose_strength, |
| "face_strength": face_strength, |
| } |
|
|
| return (image_embeds,) |
| |
| class WanVideoEmptyEmbeds: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}), |
| "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}), |
| "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), |
| }, |
| "optional": { |
| "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "Control signal for the Fun model"}), |
| "extra_latents": ("LATENT", {"tooltip": "First latent to use for the Pusa model"}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) |
| RETURN_NAMES = ("image_embeds",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def process(self, num_frames, width, height, control_embeds=None, extra_latents=None): |
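| # Build only the latent-space target shape: (channels, latent frames, latent height, latent width) derived from the VAE strides. |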
| target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1, |
| height // VAE_STRIDE[1], |
| width // VAE_STRIDE[2]) |
| |
| embeds = { |
| "target_shape": target_shape, |
| "num_frames": num_frames, |
| "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None, |
| } |
| if extra_latents is not None: |
| embeds["extra_latents"] = [{ |
| "samples": extra_latents["samples"], |
| "index": 0, |
| }] |
|
|
| return (embeds,) |
| |
| class WanVideoAddExtraLatent: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "embeds": ("WANVIDIMAGE_EMBEDS",), |
| "extra_latents": ("LATENT",), |
| "latent_index": ("INT", {"default": 0, "min": -1000, "max": 1000, "step": 1, "tooltip": "Index to insert the extra latents at in latent space"}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",) |
| RETURN_NAMES = ("image_embeds",) |
| FUNCTION = "add" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def add(self, embeds, extra_latents, latent_index): |
| |
| new_entry = { |
| "samples": extra_latents["samples"], |
| "index": latent_index, |
| } |
| |
| prev_extra_latents = embeds.get("extra_latents", None) |
| if prev_extra_latents is None: |
| extra_latents_list = [new_entry] |
| elif isinstance(prev_extra_latents, list): |
| extra_latents_list = prev_extra_latents + [new_entry] |
| else: |
| extra_latents_list = [prev_extra_latents, new_entry] |
|
|
| |
| updated = dict(embeds) |
| updated["extra_latents"] = extra_latents_list |
| return (updated,) |
| |
| class WanVideoAddLucyEditLatents: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "embeds": ("WANVIDIMAGE_EMBEDS",), |
| "extra_latents": ("LATENT",), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",) |
| RETURN_NAMES = ("image_embeds",) |
| FUNCTION = "add" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def add(self, embeds, extra_latents): |
| updated = dict(embeds) |
| updated["extra_channel_latents"] = extra_latents["samples"] |
| return (updated,) |
|
|
| class WanVideoMiniMaxRemoverEmbeds: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}), |
| "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}), |
| "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), |
| "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}), |
| "mask_latents": ("LATENT", {"tooltip": "Encoded latents to use as mask"}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) |
| RETURN_NAMES = ("image_embeds",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def process(self, num_frames, width, height, latents, mask_latents): |
| target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1, |
| height // VAE_STRIDE[1], |
| width // VAE_STRIDE[2]) |
| |
| embeds = { |
| "target_shape": target_shape, |
| "num_frames": num_frames, |
| "minimax_latents": latents["samples"].squeeze(0), |
| "minimax_mask_latents": mask_latents["samples"].squeeze(0), |
| } |
| |
| return (embeds,) |
| |
| |
| class WanVideoPhantomEmbeds: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), |
| "phantom_latent_1": ("LATENT", {"tooltip": "reference latents for the phantom model"}), |
| |
| "phantom_cfg_scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "CFG scale for the extra phantom cond pass"}), |
| "phantom_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the phantom model"}), |
| "phantom_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the phantom model"}), |
| }, |
| "optional": { |
| "phantom_latent_2": ("LATENT", {"tooltip": "reference latents for the phantom model"}), |
| "phantom_latent_3": ("LATENT", {"tooltip": "reference latents for the phantom model"}), |
| "phantom_latent_4": ("LATENT", {"tooltip": "reference latents for the phantom model"}), |
| "vace_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "VACE embeds"}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) |
| RETURN_NAMES = ("image_embeds",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def process(self, num_frames, phantom_cfg_scale, phantom_start_percent, phantom_end_percent, phantom_latent_1, phantom_latent_2=None, phantom_latent_3=None, phantom_latent_4=None, vace_embeds=None): |
| samples = phantom_latent_1["samples"].squeeze(0) |
| if phantom_latent_2 is not None: |
| samples = torch.cat([samples, phantom_latent_2["samples"].squeeze(0)], dim=1) |
| if phantom_latent_3 is not None: |
| samples = torch.cat([samples, phantom_latent_3["samples"].squeeze(0)], dim=1) |
| if phantom_latent_4 is not None: |
| samples = torch.cat([samples, phantom_latent_4["samples"].squeeze(0)], dim=1) |
| C, T, H, W = samples.shape |
|
|
| log.info(f"Phantom latents shape: {samples.shape}") |
|
|
| target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1, |
| H * 8 // VAE_STRIDE[1], |
| W * 8 // VAE_STRIDE[2]) |
| |
| embeds = { |
| "target_shape": target_shape, |
| "num_frames": num_frames, |
| "phantom_latents": samples, |
| "phantom_cfg_scale": phantom_cfg_scale, |
| "phantom_start_percent": phantom_start_percent, |
| "phantom_end_percent": phantom_end_percent, |
| } |
| if vace_embeds is not None: |
| vace_input = { |
| "vace_context": vace_embeds["vace_context"], |
| "vace_scale": vace_embeds["vace_scale"], |
| "has_ref": vace_embeds["has_ref"], |
| "vace_start_percent": vace_embeds["vace_start_percent"], |
| "vace_end_percent": vace_embeds["vace_end_percent"], |
| "vace_seq_len": vace_embeds["vace_seq_len"], |
| "additional_vace_inputs": vace_embeds["additional_vace_inputs"], |
| } |
| embeds.update(vace_input) |
| |
| return (embeds,) |
| |
| class WanVideoControlEmbeds: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}), |
| "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}), |
| "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}), |
| }, |
| "optional": { |
| "fun_ref_image": ("LATENT", {"tooltip": "Reference latent for the Fun 1.1 model"}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) |
| RETURN_NAMES = ("image_embeds",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def process(self, latents, start_percent, end_percent, fun_ref_image=None): |
| samples = latents["samples"].squeeze(0) |
| C, T, H, W = samples.shape |
|
|
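| # Recover the pixel-space frame count from the latent length, then derive the transformer sequence length ((2x2) patches per latent frame). |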
| num_frames = (T - 1) * 4 + 1 |
| seq_len = math.ceil((H * W) / 4 * ((num_frames - 1) // 4 + 1)) |
| |
| embeds = { |
| "max_seq_len": seq_len, |
| "target_shape": samples.shape, |
| "num_frames": num_frames, |
| "control_embeds": { |
| "control_images": samples, |
| "start_percent": start_percent, |
| "end_percent": end_percent, |
| "fun_ref_image": fun_ref_image["samples"][:,:, 0] if fun_ref_image is not None else None, |
| } |
| } |
| |
| return (embeds,) |
| |
| class WanVideoAddControlEmbeds: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "embeds": ("WANVIDIMAGE_EMBEDS",), |
| "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}), |
| "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}), |
| }, |
| "optional": { |
| "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}), |
| "fun_ref_image": ("LATENT", {"tooltip": "Reference latent for the Fun 1.1 model"}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) |
| RETURN_NAMES = ("image_embeds",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def process(self, embeds, start_percent, end_percent, fun_ref_image=None, latents=None): |
| new_entry = { |
| "control_images": latents["samples"].squeeze(0) if latents is not None else None, |
| "start_percent": start_percent, |
| "end_percent": end_percent, |
| "fun_ref_image": fun_ref_image["samples"][:,:, 0] if fun_ref_image is not None else None, |
| } |
|
|
| updated = dict(embeds) |
| updated["control_embeds"] = new_entry |
|
|
| return (updated,) |
| |
| class WanVideoAddPusaNoise: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "embeds": ("WANVIDIMAGE_EMBEDS",), |
| "noise_multipliers": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Noise multipliers for Pusa, can be a list of floats"}), |
| "noisy_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "tooltip": "Number of steps to apply the extra noise to"}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) |
| RETURN_NAMES = ("image_embeds",) |
| FUNCTION = "add" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "Adds latent and timestep noise multipliers when using flowmatch_pusa" |
|
|
| def add(self, embeds, noise_multipliers, noisy_steps): |
| updated = dict(embeds) |
| updated["pusa_noise_multipliers"] = noise_multipliers |
| updated["pusa_noisy_steps"] = noisy_steps |
|
|
| return (updated,) |
| |
| class WanVideoSLG: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "blocks": ("STRING", {"default": "10", "tooltip": "Blocks to skip uncond on, separated by comma, index starts from 0"}), |
| "start_percent": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}), |
| "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("SLGARGS", ) |
| RETURN_NAMES = ("slg_args",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "Skips uncond on the selected blocks" |
|
|
| def process(self, blocks, start_percent, end_percent): |
| slg_block_list = [int(x.strip()) for x in blocks.split(",")] |
|
|
| slg_args = { |
| "blocks": slg_block_list, |
| "start_percent": start_percent, |
| "end_percent": end_percent, |
| } |
| return (slg_args,) |
|
|
| |
| class WanVideoVACEEncode: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "vae": ("WANVAE",), |
| "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}), |
| "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}), |
| "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), |
| "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}), |
| "vace_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply VACE"}), |
| "vace_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply VACE"}), |
| }, |
| "optional": { |
| "input_frames": ("IMAGE",), |
| "ref_images": ("IMAGE",), |
| "input_masks": ("MASK",), |
| "prev_vace_embeds": ("WANVIDIMAGE_EMBEDS",), |
| "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) |
| RETURN_NAMES = ("vace_embeds",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def process(self, vae, width, height, num_frames, strength, vace_start_percent, vace_end_percent, input_frames=None, ref_images=None, input_masks=None, prev_vace_embeds=None, tiled_vae=False): |
| width = (width // 16) * 16 |
| height = (height // 16) * 16 |
|
|
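| # Latent-space target shape for the sampler: 16 channels, (num_frames - 1) / 4 + 1 latent frames, |
| # and spatial dimensions divided by the 8x8 VAE stride. |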
| target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1, |
| height // VAE_STRIDE[1], |
| width // VAE_STRIDE[2]) |
| |
| if input_frames is None: |
| input_frames = torch.zeros((1, 3, num_frames, height, width), device=device, dtype=vae.dtype) |
| else: |
| input_frames = input_frames.clone()[:num_frames, :, :, :3] |
| input_frames = common_upscale(input_frames.movedim(-1, 1), width, height, "lanczos", "disabled").movedim(1, -1) |
| input_frames = input_frames.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) |
| input_frames = input_frames * 2 - 1 |
| if input_masks is None: |
| input_masks = torch.ones_like(input_frames, device=device) |
| else: |
| log.info(f"input_masks shape: {input_masks.shape}") |
| input_masks = input_masks[:num_frames] |
| input_masks = common_upscale(input_masks.clone().unsqueeze(1), width, height, "nearest-exact", "disabled").squeeze(1) |
| input_masks = input_masks.to(vae.dtype).to(device) |
| input_masks = input_masks.unsqueeze(-1).unsqueeze(0).permute(0, 4, 1, 2, 3).repeat(1, 3, 1, 1, 1) |
|
|
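| # Reference images are letterboxed (padded with ones, i.e. white) to the target aspect ratio |
| # before resizing, so they keep their proportions at the generation resolution. |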
| if ref_images is not None: |
| ref_images = ref_images.clone()[..., :3] |
| |
| if ref_images.shape[0] > 1: |
| ref_images = torch.cat([ref_images[i] for i in range(ref_images.shape[0])], dim=1).unsqueeze(0) |
| |
| B, H, W, C = ref_images.shape |
| current_aspect = W / H |
| target_aspect = width / height |
| if current_aspect > target_aspect: |
| |
| new_h = int(W / target_aspect) |
| pad_h = (new_h - H) // 2 |
| padded = torch.ones(ref_images.shape[0], new_h, W, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype) |
| padded[:, pad_h:pad_h+H, :, :] = ref_images |
| ref_images = padded |
| elif current_aspect < target_aspect: |
| |
| new_w = int(H * target_aspect) |
| pad_w = (new_w - W) // 2 |
| padded = torch.ones(ref_images.shape[0], H, new_w, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype) |
| padded[:, :, pad_w:pad_w+W, :] = ref_images |
| ref_images = padded |
| ref_images = common_upscale(ref_images.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) |
| |
| ref_images = ref_images.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3).unsqueeze(0) |
| ref_images = ref_images * 2 - 1 |
|
|
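| # Encode the control frames and masks into VACE context latents on the GPU, then offload the VAE. |
| # vace_latent concatenates the frame latents with the downsampled mask latents along the channel dim. |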
| vae = vae.to(device) |
| z0 = self.vace_encode_frames(vae, input_frames, ref_images, masks=input_masks, tiled_vae=tiled_vae) |
| |
| m0 = self.vace_encode_masks(input_masks, ref_images) |
| z = self.vace_latent(z0, m0) |
| vae.to(offload_device) |
|
|
| vace_input = { |
| "vace_context": z, |
| "vace_scale": strength, |
| "has_ref": ref_images is not None, |
| "num_frames": num_frames, |
| "target_shape": target_shape, |
| "vace_start_percent": vace_start_percent, |
| "vace_end_percent": vace_end_percent, |
| "vace_seq_len": math.ceil((z[0].shape[2] * z[0].shape[3]) / 4 * z[0].shape[1]), |
| "additional_vace_inputs": [], |
| } |
|
|
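| # Chaining: carry over any previously accumulated VACE inputs and append the previous embeds, |
| # so several VACE contexts can be applied in one sampling run. |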
| if prev_vace_embeds is not None: |
| if "additional_vace_inputs" in prev_vace_embeds and prev_vace_embeds["additional_vace_inputs"]: |
| vace_input["additional_vace_inputs"] = prev_vace_embeds["additional_vace_inputs"].copy() |
| vace_input["additional_vace_inputs"].append(prev_vace_embeds) |
| |
| return (vace_input,) |
| |
| def vace_encode_frames(self, vae, frames, ref_images, masks=None, tiled_vae=False): |
| if ref_images is None: |
| ref_images = [None] * len(frames) |
| else: |
| assert len(frames) == len(ref_images) |
|
|
| pbar = ProgressBar(len(frames)) |
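| # Following the VACE formulation: with masks, frames are split into an "inactive" part (outside |
| # the mask) and a "reactive" part (inside the mask); both are encoded and concatenated channel-wise. |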
| if masks is None: |
| latents = vae.encode(frames, device=device, tiled=tiled_vae) |
| else: |
| inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] |
| reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] |
| del frames |
| inactive = vae.encode(inactive, device=device, tiled=tiled_vae) |
| reactive = vae.encode(reactive, device=device, tiled=tiled_vae) |
| latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] |
| del inactive, reactive |
| |
| |
| cat_latents = [] |
| for latent, refs in zip(latents, ref_images): |
| if refs is not None: |
| ref_latent = vae.encode(refs, device=device, tiled=tiled_vae) |
| if masks is not None: |
| ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] |
| assert all([x.shape[1] == 1 for x in ref_latent]) |
| latent = torch.cat([*ref_latent, latent], dim=1) |
| cat_latents.append(latent) |
| pbar.update(1) |
| return cat_latents |
|
|
| def vace_encode_masks(self, masks, ref_images=None): |
| if ref_images is None: |
| ref_images = [None] * len(masks) |
| else: |
| assert len(masks) == len(ref_images) |
|
|
| result_masks = [] |
| pbar = ProgressBar(len(masks)) |
| for mask, refs in zip(masks, ref_images): |
| _c, depth, height, width = mask.shape |
| new_depth = int((depth + 3) // VAE_STRIDE[0]) |
| height = 2 * (int(height) // (VAE_STRIDE[1] * 2)) |
| width = 2 * (int(width) // (VAE_STRIDE[2] * 2)) |
|
|
| |
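| # Rearrange the pixel-space mask into (stride_h * stride_w, depth, H/8, W/8) blocks, then |
| # nearest-exact downsample along time so the mask matches the VACE latent layout. |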
| mask = mask[0, :, :, :] |
| mask = mask.view( |
| depth, height, VAE_STRIDE[1], width, VAE_STRIDE[2] |
| ) |
| mask = mask.permute(2, 4, 0, 1, 3) |
| mask = mask.reshape( |
| VAE_STRIDE[1] * VAE_STRIDE[2], depth, height, width |
| ) |
|
|
| |
| mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) |
|
|
| if refs is not None: |
| length = len(refs) |
| mask_pad = torch.zeros_like(mask[:, :length, :, :]) |
| mask = torch.cat((mask_pad, mask), dim=1) |
| result_masks.append(mask) |
| pbar.update(1) |
| return result_masks |
|
|
| def vace_latent(self, z, m): |
| return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] |
|
|
|
|
| |
| class WanVideoContextOptions: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "context_schedule": (["uniform_standard", "uniform_looped", "static_standard"],), |
| "context_frames": ("INT", {"default": 81, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ), |
| "context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ), |
| "context_overlap": ("INT", {"default": 16, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ), |
| "freenoise": ("BOOLEAN", {"default": True, "tooltip": "Shuffle the noise"}), |
| "verbose": ("BOOLEAN", {"default": False, "tooltip": "Print debug output"}), |
| }, |
| "optional": { |
| "fuse_method": (["linear", "pyramid"], {"default": "linear", "tooltip": "Window weight function: linear=ramps at edges only, pyramid=triangular weights peaking in middle"}), |
| "reference_latent": ("LATENT", {"tooltip": "Image to be used as init for I2V models for windows where first frame is not the actual first frame. Mostly useful with MAGREF model"}), |
| } |
| } |
|
|
| RETURN_TYPES = ("WANVIDCONTEXT", ) |
| RETURN_NAMES = ("context_options",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "Context options for WanVideo, allows splitting the video into context windows and attemps blending them for longer generations than the model and memory otherwise would allow." |
|
|
| def process(self, context_schedule, context_frames, context_stride, context_overlap, freenoise, verbose, image_cond_start_step=6, image_cond_window_count=2, vae=None, fuse_method="linear", reference_latent=None): |
| context_options = { |
| "context_schedule":context_schedule, |
| "context_frames":context_frames, |
| "context_stride":context_stride, |
| "context_overlap":context_overlap, |
| "freenoise":freenoise, |
| "verbose":verbose, |
| "fuse_method":fuse_method, |
| "reference_latent":reference_latent["samples"] if reference_latent is not None else None, |
| } |
|
|
| return (context_options,) |
| |
| |
| class WanVideoFlowEdit: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "source_embeds": ("WANVIDEOTEXTEMBEDS", ), |
| "skip_steps": ("INT", {"default": 4, "min": 0}), |
| "drift_steps": ("INT", {"default": 0, "min": 0}), |
| "drift_flow_shift": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 30.0, "step": 0.01}), |
| "source_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), |
| "drift_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), |
| }, |
| "optional": { |
| "source_image_embeds": ("WANVIDIMAGE_EMBEDS", ), |
| } |
| } |
|
|
| RETURN_TYPES = ("FLOWEDITARGS", ) |
| RETURN_NAMES = ("flowedit_args",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "Flowedit options for WanVideo" |
|
|
| def process(self, **kwargs): |
| return (kwargs,) |
| |
| class WanVideoLoopArgs: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "shift_skip": ("INT", {"default": 6, "min": 0, "tooltip": "Skip step of latent shift"}), |
| "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the looping effect"}), |
| "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the looping effect"}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("LOOPARGS", ) |
| RETURN_NAMES = ("loop_args",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "Looping through latent shift as shown in https://github.com/YisuiTT/Mobius/" |
|
|
| def process(self, **kwargs): |
| return (kwargs,) |
|
|
| class WanVideoExperimentalArgs: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "video_attention_split_steps": ("STRING", {"default": "", "tooltip": "Steps to split self attention when using multiple prompts"}), |
| "cfg_zero_star": ("BOOLEAN", {"default": False, "tooltip": "https://github.com/WeichenFan/CFG-Zero-star"}), |
| "use_zero_init": ("BOOLEAN", {"default": False}), |
| "zero_star_steps": ("INT", {"default": 0, "min": 0, "tooltip": "Steps to split self attention when using multiple prompts"}), |
| "use_fresca": ("BOOLEAN", {"default": False, "tooltip": "https://github.com/WikiChao/FreSca"}), |
| "fresca_scale_low": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), |
| "fresca_scale_high": ("FLOAT", {"default": 1.25, "min": 0.0, "max": 10.0, "step": 0.01}), |
| "fresca_freq_cutoff": ("INT", {"default": 20, "min": 0, "max": 10000, "step": 1}), |
| "use_tcfg": ("BOOLEAN", {"default": False, "tooltip": "https://arxiv.org/abs/2503.18137 TCFG: Tangential Damping Classifier-free Guidance. CFG artifacts reduction."}), |
| "raag_alpha": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Alpha value for RAAG, 1.0 is default, 0.0 is disabled."}), |
| "bidirectional_sampling": ("BOOLEAN", {"default": False, "tooltip": "Enable bidirectional sampling, based on https://github.com/ff2416/WanFM"}), |
| "temporal_score_rescaling": ("BOOLEAN", {"default": False, "tooltip": "Enable temporal score rescaling: https://github.com/temporalscorerescaling/TSR/"}), |
| "tsr_k": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "The sampling temperature"}), |
| "tsr_sigma": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "How early TSR steer the sampling process"}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("EXPERIMENTALARGS", ) |
| RETURN_NAMES = ("exp_args",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "Experimental stuff" |
| EXPERIMENTAL = True |
|
|
| def process(self, **kwargs): |
| return (kwargs,) |
| |
| class WanVideoFreeInitArgs: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "freeinit_num_iters": ("INT", {"default": 3, "min": 1, "max": 10, "tooltip": "Number of FreeInit iterations"}), |
| "freeinit_method": (["butterworth", "ideal", "gaussian", "none"], {"default": "ideal", "tooltip": "Frequency filter type"}), |
| "freeinit_n": ("INT", {"default": 4, "min": 1, "max": 10, "tooltip": "Butterworth filter order (only for butterworth)"}), |
| "freeinit_d_s": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Spatial filter cutoff"}), |
| "freeinit_d_t": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Temporal filter cutoff"}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("FREEINITARGS", ) |
| RETURN_NAMES = ("freeinit_args",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "https://github.com/TianxingWu/FreeInit; FreeInit, a concise yet effective method to improve temporal consistency of videos generated by diffusion models" |
| EXPERIMENTAL = True |
|
|
| def process(self, **kwargs): |
| return (kwargs,) |
| |
| class WanVideoScheduler: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "scheduler": (scheduler_list, {"default": "unipc"}), |
| "steps": ("INT", {"default": 30, "min": 1, "tooltip": "Number of steps for the scheduler"}), |
| "shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}), |
| "start_step": ("INT", {"default": 0, "min": 0, "tooltip": "Starting step for the scheduler"}), |
| "end_step": ("INT", {"default": -1, "min": -1, "tooltip": "Ending step for the scheduler"}) |
| }, |
| "optional": { |
| "sigmas": ("SIGMAS", ), |
| }, |
| "hidden": { |
| "unique_id": "UNIQUE_ID", |
| }, |
| } |
|
|
| RETURN_TYPES = ("SIGMAS", "INT", "FLOAT", scheduler_list, "INT", "INT",) |
| RETURN_NAMES = ("sigmas", "steps", "shift", "scheduler", "start_step", "end_step") |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| EXPERIMENTAL = True |
|
|
| def process(self, scheduler, steps, start_step, end_step, shift, unique_id, sigmas=None): |
| sample_scheduler, timesteps, start_idx, end_idx = get_scheduler( |
| scheduler, |
| steps, |
| start_step, end_step, shift, |
| device, |
| sigmas=sigmas, |
| log_timesteps=True) |
| |
| scheduler_dict = { |
| "sample_scheduler": sample_scheduler, |
| "timesteps": timesteps, |
| } |
|
|
| try: |
| from server import PromptServer |
| import io |
| import base64 |
| import matplotlib.pyplot as plt |
| except Exception: |
| PromptServer = None |
| if unique_id and PromptServer is not None: |
| try: |
| |
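| # Render the sigma schedule with matplotlib and push it to the node UI as a |
| # base64-embedded image via PromptServer progress text. |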
| sigmas_np = sample_scheduler.full_sigmas.cpu().numpy() |
| if not np.isclose(sigmas_np[-1], 0.0, atol=1e-6): |
| sigmas_np = np.append(sigmas_np, 0.0) |
| buf = io.BytesIO() |
| fig = plt.figure(facecolor='#353535') |
| ax = fig.add_subplot(111) |
| ax.set_facecolor('#353535') |
| x_values = range(0, len(sigmas_np)) |
| ax.plot(x_values, sigmas_np) |
| |
| ax.scatter(x_values, sigmas_np, color='white', s=20, zorder=3) |
| for x, y in zip(x_values, sigmas_np): |
| if len(sigmas_np) <= 10: |
| ax.annotate(f"{y:.3f}", (x, y), textcoords="offset points", xytext=(10, 1), ha='center', color='orange', fontsize=12) |
| ax.set_xticks(x_values) |
| ax.set_title("Sigmas", color='white') |
| ax.set_xlabel("Step", color='white') |
| ax.set_ylabel("Sigma Value", color='white') |
| ax.tick_params(axis='x', colors='white', labelsize=10) |
| ax.tick_params(axis='y', colors='white', labelsize=10) |
| |
| end_idx += 1 |
| if end_idx != -1 and 0 <= end_idx < len(sigmas_np) - 1: |
| ax.axvline(end_idx, color='red', linestyle='--', linewidth=2, label='end_step split') |
| |
| if start_idx > 0 and 0 <= start_idx < len(sigmas_np): |
| ax.axvline(start_idx, color='green', linestyle='--', linewidth=2, label='start_step split') |
| if (end_idx != -1 and 0 <= end_idx < len(sigmas_np)) or (start_idx > 0 and 0 <= start_idx < len(sigmas_np)): |
| ax.legend() |
| if start_idx < end_idx and 0 <= start_idx < len(sigmas_np) and 0 < end_idx < len(sigmas_np): |
| ax.axvspan(start_idx, end_idx, color='lightblue', alpha=0.1, label='Sampled Range') |
| plt.tight_layout() |
| plt.savefig(buf, format='png') |
| plt.close(fig) |
| buf.seek(0) |
| img_base64 = base64.b64encode(buf.read()).decode('utf-8') |
| buf.close() |
|
|
| |
| html_img = f"<img src='data:image/png;base64,{img_base64}' alt='Sigmas Plot' style='max-width:100%; height:100%; overflow:hidden; display:block;'>" |
| PromptServer.instance.send_progress_text(html_img, unique_id) |
| except Exception as e: |
| log.warning(f"Failed to send sigmas plot: {e}") |
|
|
| return (sigmas, steps, shift, scheduler_dict, start_step, end_step) |
| |
| class WanVideoSchedulerSA_ODE: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "use_adaptive_order": ("BOOLEAN", {"default": False, "tooltip": "Use adaptive order"}), |
| "use_velocity_smoothing": ("BOOLEAN", {"default": True, "tooltip": "Use velocity smoothing"}), |
| "convergence_threshold": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Convergence threshold for velocity smoothing"}), |
| "smoothing_factor": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Smoothing factor for velocity smoothing"}), |
| "steps": ("INT", {"default": 30, "min": 1, "tooltip": "Number of steps for the scheduler"}), |
| "shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}), |
| "start_step": ("INT", {"default": 0, "min": 0, "tooltip": "Starting step for the scheduler"}), |
| "end_step": ("INT", {"default": -1, "min": -1, "tooltip": "Ending step for the scheduler"}) |
| }, |
| "optional": { |
| "sigmas": ("SIGMAS", ), |
| }, |
| } |
|
|
| RETURN_TYPES = ("SIGMAS", "INT", "FLOAT", scheduler_list, "INT", "INT",) |
| RETURN_NAMES = ("sigmas", "steps", "shift", "scheduler", "start_step", "end_step") |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| EXPERIMENTAL = True |
|
|
| def process(self, steps, start_step, end_step, shift, use_adaptive_order, use_velocity_smoothing, convergence_threshold, smoothing_factor, sigmas=None): |
| sample_scheduler, timesteps, _, _ = get_scheduler( |
| scheduler="sa_ode_stable/lowstep", |
| steps=steps, |
| start_step=start_step, end_step=end_step, shift=shift, |
| device=device, |
| sigmas=sigmas, |
| log_timesteps=True, |
| use_adaptive_order=use_adaptive_order, |
| use_velocity_smoothing=use_velocity_smoothing, |
| convergence_threshold=convergence_threshold, |
| smoothing_factor=smoothing_factor |
| ) |
| |
| scheduler_dict = { |
| "sample_scheduler": sample_scheduler, |
| "timesteps": timesteps, |
| } |
|
|
| return (sigmas, steps, shift, scheduler_dict, start_step, end_step) |
|
|
| rope_functions = ["default", "comfy", "comfy_chunked"] |
| class WanVideoRoPEFunction: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "rope_function": (rope_functions, {"default": "comfy"}), |
| "ntk_scale_f": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), |
| "ntk_scale_h": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), |
| "ntk_scale_w": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), |
| }, |
| } |
|
|
| RETURN_TYPES = (rope_functions, ) |
| RETURN_NAMES = ("rope_function",) |
| FUNCTION = "process" |
| CATEGORY = "WanVideoWrapper" |
| EXPERIMENTAL = True |
|
|
| def process(self, rope_function, ntk_scale_f, ntk_scale_h, ntk_scale_w): |
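| # If any NTK scale differs from 1.0, return a dict so downstream code can apply per-axis |
| # NTK RoPE scaling; otherwise pass the plain rope_function string through unchanged. |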
| if ntk_scale_f != 1.0 or ntk_scale_h != 1.0 or ntk_scale_w != 1.0: |
| rope_func_dict = { |
| "rope_function": rope_function, |
| "ntk_scale_f": ntk_scale_f, |
| "ntk_scale_h": ntk_scale_h, |
| "ntk_scale_w": ntk_scale_w, |
| } |
| return (rope_func_dict,) |
| return (rope_function,) |
|
|
|
|
| |
| class WanVideoDecode: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "vae": ("WANVAE",), |
| "samples": ("LATENT",), |
| "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": ( |
| "Drastically reduces memory use but will introduce seams at tile stride boundaries. " |
| "The location and number of seams is dictated by the tile stride size. " |
| "The visibility of seams can be controlled by increasing the tile size. " |
| "Seams become less obvious at 1.5x stride and are barely noticeable at 2x stride size. " |
| "Which is to say if you use a stride width of 160, the seams are barely noticeable with a tile width of 320." |
| )}), |
| "tile_x": ("INT", {"default": 272, "min": 40, "max": 2048, "step": 8, "tooltip": "Tile width in pixels. Smaller values use less VRAM but will make seams more obvious."}), |
| "tile_y": ("INT", {"default": 272, "min": 40, "max": 2048, "step": 8, "tooltip": "Tile height in pixels. Smaller values use less VRAM but will make seams more obvious."}), |
| "tile_stride_x": ("INT", {"default": 144, "min": 32, "max": 2040, "step": 8, "tooltip": "Tile stride width in pixels. Smaller values use less VRAM but will introduce more seams."}), |
| "tile_stride_y": ("INT", {"default": 128, "min": 32, "max": 2040, "step": 8, "tooltip": "Tile stride height in pixels. Smaller values use less VRAM but will introduce more seams."}), |
| }, |
| "optional": { |
| "normalization": (["default", "minmax"], {"advanced": True}), |
| } |
| } |
|
|
| @classmethod |
| def VALIDATE_INPUTS(s, tile_x, tile_y, tile_stride_x, tile_stride_y): |
| if tile_x <= tile_stride_x: |
| return "Tile width must be larger than the tile stride width." |
| if tile_y <= tile_stride_y: |
| return "Tile height must be larger than the tile stride height." |
| return True |
|
|
| RETURN_TYPES = ("IMAGE",) |
| RETURN_NAMES = ("images",) |
| FUNCTION = "decode" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def decode(self, vae, samples, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, normalization="default"): |
| mm.soft_empty_cache() |
| video = samples.get("video", None) |
| if video is not None: |
| video.clamp_(-1.0, 1.0) |
| video.add_(1.0).div_(2.0) |
| return (video.cpu().float(),) |
| latents = samples["samples"] |
| end_image = samples.get("end_image", None) |
| has_ref = samples.get("has_ref", False) |
| drop_last = samples.get("drop_last", False) |
| is_looped = samples.get("looped", False) |
|
|
| vae.to(device) |
|
|
| latents = latents.to(device = device, dtype = vae.dtype) |
|
|
| mm.soft_empty_cache() |
|
|
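| # Trim helper latents: drop a prepended reference frame (has_ref) and/or a trailing |
| # padding frame (drop_last) so they don't show up in the decoded video. |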
| if has_ref: |
| latents = latents[:, :, 1:] |
| if drop_last: |
| latents = latents[:, :, :-1] |
|
|
| if type(vae).__name__ == "TAEHV": |
| images = vae.decode_video(latents.permute(0, 2, 1, 3, 4))[0].permute(1, 0, 2, 3) |
| images = torch.clamp(images, 0.0, 1.0) |
| images = images.permute(1, 2, 3, 0).cpu().float() |
| return (images,) |
| else: |
| if end_image is not None: |
| enable_vae_tiling = False |
| images = vae.decode(latents, device=device, end_=(end_image is not None), tiled=enable_vae_tiling, tile_size=(tile_x//8, tile_y//8), tile_stride=(tile_stride_x//8, tile_stride_y//8))[0] |
| |
| |
| images = images.cpu().float() |
|
|
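| # "minmax" stretches the decoded clip to the full [0, 1] range using its own min/max, |
| # while the default assumes the VAE output is already in [-1, 1] and maps it to [0, 1]. |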
| if normalization == "minmax": |
| images.sub_(images.min()).div_(images.max() - images.min()) |
| else: |
| images.clamp_(-1.0, 1.0) |
| images.add_(1.0).div_(2.0) |
| |
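| # For looped videos, decode a short wrap-around segment (last 3 + first 2 latent frames) |
| # and splice it over the seam so the loop point decodes cleanly. |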
| if is_looped: |
| temp_latents = torch.cat([latents[:, :, -3:]] + [latents[:, :, :2]], dim=2) |
| temp_images = vae.decode(temp_latents, device=device, end_=(end_image is not None), tiled=enable_vae_tiling, tile_size=(tile_x//vae.upsampling_factor, tile_y//vae.upsampling_factor), tile_stride=(tile_stride_x//vae.upsampling_factor, tile_stride_y//vae.upsampling_factor))[0] |
| temp_images = temp_images.cpu().float() |
| temp_images = (temp_images - temp_images.min()) / (temp_images.max() - temp_images.min()) |
| images = torch.cat([temp_images[:, 9:].to(images), images[:, 5:]], dim=1) |
|
|
| if end_image is not None: |
| images = images[:, 0:-1] |
|
|
| |
| vae.to(offload_device) |
| mm.soft_empty_cache() |
|
|
| images.clamp_(0.0, 1.0) |
|
|
| return (images.permute(1, 2, 3, 0),) |
|
|
| |
| class WanVideoEncodeLatentBatch: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "vae": ("WANVAE",), |
| "images": ("IMAGE",), |
| "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}), |
| "tile_x": ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"}), |
| "tile_y": ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"}), |
| "tile_stride_x": ("INT", {"default": 144, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"}), |
| "tile_stride_y": ("INT", {"default": 128, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("LATENT",) |
| RETURN_NAMES = ("samples",) |
| FUNCTION = "encode" |
| CATEGORY = "WanVideoWrapper" |
| DESCRIPTION = "Encodes a batch of images individually to create a latent video batch where each video is a single frame, useful for I2V init purposes, for example as multiple context window inits" |
|
|
| def encode(self, vae, images, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, latent_strength=1.0): |
| vae.to(device) |
|
|
| images = images.clone() |
|
|
| B, H, W, C = images.shape |
| if W % 16 != 0 or H % 16 != 0: |
| new_height = (H // 16) * 16 |
| new_width = (W // 16) * 16 |
| log.warning(f"Image size {W}x{H} is not divisible by 16, resizing to {new_width}x{new_height}") |
| images = common_upscale(images.movedim(-1, 1), new_width, new_height, "lanczos", "disabled").movedim(1, -1) |
|
|
| if images.shape[-1] == 4: |
| images = images[..., :3] |
| images = images.to(vae.dtype).to(device) * 2.0 - 1.0 |
|
|
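| # Encode each image as its own single-frame video, shaped (1, C, 1, H, W), then stack |
| # the results into a batch of single-frame latents. |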
| latent_list = [] |
| for img in images: |
| if enable_vae_tiling and tile_x is not None: |
| latent = vae.encode(img.unsqueeze(0).unsqueeze(0).permute(0, 4, 1, 2, 3), device=device, tiled=enable_vae_tiling, tile_size=(tile_x//vae.upsampling_factor, tile_y//vae.upsampling_factor), tile_stride=(tile_stride_x//vae.upsampling_factor, tile_stride_y//vae.upsampling_factor)) |
| else: |
| latent = vae.encode(img.unsqueeze(0).unsqueeze(0).permute(0, 4, 1, 2, 3), device=device, tiled=enable_vae_tiling) |
| |
| if latent_strength != 1.0: |
| latent *= latent_strength |
| latent_list.append(latent.squeeze(0).cpu()) |
| latents_out = torch.stack(latent_list, dim=0) |
|
|
| log.info(f"WanVideoEncode: Encoded latents shape {latents_out.shape}") |
| vae.to(offload_device) |
| mm.soft_empty_cache() |
|
|
| return ({"samples": latents_out},) |
|
|
| class WanVideoEncode: |
| @classmethod |
| def INPUT_TYPES(s): |
| return {"required": { |
| "vae": ("WANVAE",), |
| "image": ("IMAGE",), |
| "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}), |
| "tile_x": ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"}), |
| "tile_y": ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"}), |
| "tile_stride_x": ("INT", {"default": 144, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"}), |
| "tile_stride_y": ("INT", {"default": 128, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"}), |
| }, |
| "optional": { |
| "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for leapfusion I2V where some noise can add motion and give sharper results"}), |
| "latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for leapfusion I2V where lower values allow for more motion"}), |
| "mask": ("MASK", ), |
| } |
| } |
|
|
| RETURN_TYPES = ("LATENT",) |
| RETURN_NAMES = ("samples",) |
| FUNCTION = "encode" |
| CATEGORY = "WanVideoWrapper" |
|
|
| def encode(self, vae, image, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, noise_aug_strength=0.0, latent_strength=1.0, mask=None): |
| vae.to(device) |
|
|
| image = image.clone() |
|
|
| B, H, W, C = image.shape |
| if W % 16 != 0 or H % 16 != 0: |
| new_height = (H // 16) * 16 |
| new_width = (W // 16) * 16 |
| log.warning(f"Image size {W}x{H} is not divisible by 16, resizing to {new_width}x{new_height}") |
| image = common_upscale(image.movedim(-1, 1), new_width, new_height, "lanczos", "disabled").movedim(1, -1) |
|
|
| if image.shape[-1] == 4: |
| image = image[..., :3] |
| image = image.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) |
|
|
| if noise_aug_strength > 0.0: |
| image = add_noise_to_reference_video(image, ratio=noise_aug_strength) |
|
|
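| # TAEHV appears to take (B, T, C, H, W) input in [0, 1], while the full VAE takes |
| # (B, C, T, H, W) in [-1, 1], hence the different permute and scaling below. |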
| if isinstance(vae, TAEHV): |
| latents = vae.encode_video(image.permute(0, 2, 1, 3, 4), parallel=False) |
| latents = latents.permute(0, 2, 1, 3, 4) |
| else: |
| latents = vae.encode(image * 2.0 - 1.0, device=device, tiled=enable_vae_tiling, tile_size=(tile_x//vae.upsampling_factor, tile_y//vae.upsampling_factor), tile_stride=(tile_stride_x//vae.upsampling_factor, tile_stride_y//vae.upsampling_factor)) |
| |
| vae.to(offload_device) |
| if latent_strength != 1.0: |
| latents *= latent_strength |
|
|
| log.info(f"WanVideoEncode: Encoded latents shape {latents.shape}") |
| mm.soft_empty_cache() |
| |
| return ({"samples": latents, "noise_mask": mask},) |
|
|
| NODE_CLASS_MAPPINGS = { |
| "WanVideoDecode": WanVideoDecode, |
| "WanVideoTextEncode": WanVideoTextEncode, |
| "WanVideoTextEncodeSingle": WanVideoTextEncodeSingle, |
| "WanVideoClipVisionEncode": WanVideoClipVisionEncode, |
| "WanVideoImageToVideoEncode": WanVideoImageToVideoEncode, |
| "WanVideoEncode": WanVideoEncode, |
| "WanVideoEncodeLatentBatch": WanVideoEncodeLatentBatch, |
| "WanVideoEmptyEmbeds": WanVideoEmptyEmbeds, |
| "WanVideoEnhanceAVideo": WanVideoEnhanceAVideo, |
| "WanVideoContextOptions": WanVideoContextOptions, |
| "WanVideoTextEmbedBridge": WanVideoTextEmbedBridge, |
| "WanVideoFlowEdit": WanVideoFlowEdit, |
| "WanVideoControlEmbeds": WanVideoControlEmbeds, |
| "WanVideoSLG": WanVideoSLG, |
| "WanVideoLoopArgs": WanVideoLoopArgs, |
| "WanVideoSetBlockSwap": WanVideoSetBlockSwap, |
| "WanVideoExperimentalArgs": WanVideoExperimentalArgs, |
| "WanVideoVACEEncode": WanVideoVACEEncode, |
| "WanVideoPhantomEmbeds": WanVideoPhantomEmbeds, |
| "WanVideoRealisDanceLatents": WanVideoRealisDanceLatents, |
| "WanVideoApplyNAG": WanVideoApplyNAG, |
| "WanVideoMiniMaxRemoverEmbeds": WanVideoMiniMaxRemoverEmbeds, |
| "WanVideoFreeInitArgs": WanVideoFreeInitArgs, |
| "WanVideoSetRadialAttention": WanVideoSetRadialAttention, |
| "WanVideoBlockList": WanVideoBlockList, |
| "WanVideoTextEncodeCached": WanVideoTextEncodeCached, |
| "WanVideoAddExtraLatent": WanVideoAddExtraLatent, |
| "WanVideoScheduler": WanVideoScheduler, |
| "WanVideoAddStandInLatent": WanVideoAddStandInLatent, |
| "WanVideoAddControlEmbeds": WanVideoAddControlEmbeds, |
| "WanVideoAddMTVMotion": WanVideoAddMTVMotion, |
| "WanVideoRoPEFunction": WanVideoRoPEFunction, |
| "WanVideoAddPusaNoise": WanVideoAddPusaNoise, |
| "WanVideoAnimateEmbeds": WanVideoAnimateEmbeds, |
| "WanVideoAddLucyEditLatents": WanVideoAddLucyEditLatents, |
| "WanVideoSchedulerSA_ODE": WanVideoSchedulerSA_ODE, |
| } |
|
|
| NODE_DISPLAY_NAME_MAPPINGS = { |
| "WanVideoDecode": "WanVideo Decode", |
| "WanVideoTextEncode": "WanVideo TextEncode", |
| "WanVideoTextEncodeSingle": "WanVideo TextEncodeSingle", |
| "WanVideoTextImageEncode": "WanVideo TextImageEncode (IP2V)", |
| "WanVideoClipVisionEncode": "WanVideo ClipVision Encode", |
| "WanVideoImageToVideoEncode": "WanVideo ImageToVideo Encode", |
| "WanVideoEncode": "WanVideo Encode", |
| "WanVideoEncodeLatentBatch": "WanVideo Encode Latent Batch", |
| "WanVideoEmptyEmbeds": "WanVideo Empty Embeds", |
| "WanVideoEnhanceAVideo": "WanVideo Enhance-A-Video", |
| "WanVideoContextOptions": "WanVideo Context Options", |
| "WanVideoTextEmbedBridge": "WanVideo TextEmbed Bridge", |
| "WanVideoFlowEdit": "WanVideo FlowEdit", |
| "WanVideoControlEmbeds": "WanVideo Control Embeds", |
| "WanVideoSLG": "WanVideo SLG", |
| "WanVideoLoopArgs": "WanVideo Loop Args", |
| "WanVideoSetBlockSwap": "WanVideo Set BlockSwap", |
| "WanVideoExperimentalArgs": "WanVideo Experimental Args", |
| "WanVideoVACEEncode": "WanVideo VACE Encode", |
| "WanVideoPhantomEmbeds": "WanVideo Phantom Embeds", |
| "WanVideoRealisDanceLatents": "WanVideo RealisDance Latents", |
| "WanVideoApplyNAG": "WanVideo Apply NAG", |
| "WanVideoMiniMaxRemoverEmbeds": "WanVideo MiniMax Remover Embeds", |
| "WanVideoFreeInitArgs": "WanVideo Free Init Args", |
| "WanVideoSetRadialAttention": "WanVideo Set Radial Attention", |
| "WanVideoBlockList": "WanVideo Block List", |
| "WanVideoTextEncodeCached": "WanVideo TextEncode Cached", |
| "WanVideoAddExtraLatent": "WanVideo Add Extra Latent", |
| "WanVideoAddStandInLatent": "WanVideo Add StandIn Latent", |
| "WanVideoAddControlEmbeds": "WanVideo Add Control Embeds", |
| "WanVideoAddMTVMotion": "WanVideo MTV Crafter Motion", |
| "WanVideoRoPEFunction": "WanVideo RoPE Function", |
| "WanVideoAddPusaNoise": "WanVideo Add Pusa Noise", |
| "WanVideoAnimateEmbeds": "WanVideo Animate Embeds", |
| "WanVideoAddLucyEditLatents": "WanVideo Add LucyEdit Latents", |
| "WanVideoSchedulerSA_ODE": "WanVideo Scheduler SA-ODE", |
| } |
|
|