# Source: MyCustomNodes / Salia_RifeVFI_Insert.py
# (uploaded by saliacoel as "Upload Salia_RifeVFI_Insert.py", commit 71aa491)
# custom_nodes\comfyui-salia_online\nodes\rife_insert_between.py
#
# Node: Insert RIFE-generated in-between frames between two indices of a batch.
#
# Inputs:
# - batch (IMAGE): input batch of frames
# - start (INT): start index in batch
# - end (INT): end index in batch
# - multiplier (INT): number of in-between frames to INSERT (1 => insert 1 frame)
#
# Internals:
# - Extract batch[start] and batch[end]
# - Make a 2-frame batch and run ComfyUI-Frame-Interpolation's RIFE_VFI lazily
# - Call RIFE with (multiplier + 1) because upstream multiplier is a factor
# (2 frames, factor=2 => 1 middle frame; factor=3 => 2 middle frames; etc.)
# - Remove FIRST and LAST from RIFE output (keep only the in-betweens)
# - Insert in-betweens between start and end in the original batch
#
# Output:
# - IMAGE: new batch with inserted frames
#
# Notes:
# - If end > start+1, the existing frames start+1 .. end-1 are REPLACED by the
#   interpolated frames, so the inserted frames sit directly between start and end.
#
from __future__ import annotations
import os
import sys
import importlib
import threading
from typing import Tuple
import torch
# Serializes the one-time lazy import performed by _lazy_get_rife_class().
_IMPORT_LOCK = threading.Lock()

# Cached RIFE_VFI class object; remains None until the first successful import.
_RIFE_CLASS = None

# Fixed RIFE settings passed to RIFE_VFI.vfi() on every call (intended to match
# the settings of the separate lazy RIFE node). They are not exposed as inputs.
_HARDCODED_CKPT_NAME: str = "rife47.pth"
_HARDCODED_CLEAR_CACHE_AFTER_N_FRAMES: int = 10
_HARDCODED_FAST_MODE: bool = True
_HARDCODED_ENSEMBLE: bool = True
_HARDCODED_SCALE_FACTOR: float = 1.0
def _find_cfi_dir(custom_nodes_dir: str) -> str:
    """Return the path of the ComfyUI-Frame-Interpolation folder in *custom_nodes_dir*.

    The canonical folder name is tried first. If that is missing, the directory
    entries are compared case-insensitively, so a differently-cased install
    (e.g. ``comfyui-frame-interpolation``) is still found on case-sensitive
    filesystems.

    Raises:
        FileNotFoundError: if no matching directory exists.
    """
    folder_name = "ComfyUI-Frame-Interpolation"
    expected = os.path.join(custom_nodes_dir, folder_name)
    if os.path.isdir(expected):
        return expected

    try:
        entries = sorted(os.listdir(custom_nodes_dir))
    except OSError:
        # custom_nodes_dir itself is missing/unreadable: fall through to the error.
        entries = []
    wanted = folder_name.lower()
    for entry in entries:
        candidate = os.path.join(custom_nodes_dir, entry)
        if entry.lower() == wanted and os.path.isdir(candidate):
            return candidate

    raise FileNotFoundError(
        f"Could not find the ComfyUI-Frame-Interpolation folder in:\n  {custom_nodes_dir}\n"
        f"Expected it at:\n  {expected}"
    )


def _lazy_get_rife_class():
    """Import and cache ComfyUI-Frame-Interpolation's ``RIFE_VFI`` class on first use.

    The import is deferred until the node actually runs, so the interpolation
    package is not imported at ComfyUI startup. The class is cached in the
    module-level ``_RIFE_CLASS``; ``_IMPORT_LOCK`` makes sure only one thread
    performs the import.

    Returns:
        The ``RIFE_VFI`` class object.

    Raises:
        FileNotFoundError: if the ComfyUI-Frame-Interpolation folder cannot be found.
        ImportError: if ``vfi_models.rife`` cannot be imported or lacks ``RIFE_VFI``.
    """
    global _RIFE_CLASS
    # Fast path: no locking once the class has been cached.
    if _RIFE_CLASS is not None:
        return _RIFE_CLASS
    with _IMPORT_LOCK:
        # Re-check: another thread may have completed the import while we waited.
        if _RIFE_CLASS is not None:
            return _RIFE_CLASS

        # This file lives at <custom_nodes>/<this package>/nodes/<file>.py,
        # so custom_nodes is two levels above this file's directory.
        this_dir = os.path.dirname(os.path.abspath(__file__))
        custom_nodes_dir = os.path.abspath(os.path.join(this_dir, "..", ".."))
        cfi_dir = _find_cfi_dir(custom_nodes_dir)

        # Put the extension folder on sys.path so that "vfi_models.rife" and
        # "vfi_utils" resolve as top-level imports.
        if cfi_dir not in sys.path:
            sys.path.insert(0, cfi_dir)

        rife_mod = importlib.import_module("vfi_models.rife")
        rife_cls = getattr(rife_mod, "RIFE_VFI", None)
        if rife_cls is None:
            raise ImportError("vfi_models.rife imported, but RIFE_VFI class was not found.")

        _RIFE_CLASS = rife_cls
        return _RIFE_CLASS
class SALIA_RIFE_INSERT_BETWEEN:
    """Node that inserts RIFE-interpolated frames between two indices of a batch.

    ``batch[start]`` and ``batch[end]`` are interpolated with
    ComfyUI-Frame-Interpolation's ``RIFE_VFI`` (imported lazily). The generated
    in-between frames are placed directly after ``batch[start]``. Any frames that
    were strictly between ``start`` and ``end`` are replaced.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "batch": ("IMAGE",),
                "start": ("INT", {"default": 0, "min": 0, "step": 1}),
                "end": ("INT", {"default": 1, "min": 0, "step": 1}),
                # user multiplier = number of inserted frames (not RIFE's factor)
                "multiplier": ("INT", {"default": 1, "min": 1, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("IMAGE",)
    FUNCTION = "insert"
    CATEGORY = "salia_online/VFI"

    def insert(self, batch: torch.Tensor, start: int, end: int, multiplier: int) -> Tuple[torch.Tensor]:
        """Return *batch* with ``multiplier`` RIFE frames inserted between start and end.

        Args:
            batch: frame batch indexed along dim 0.
            start: index of the first boundary frame (must be < end).
            end: index of the second boundary frame.
            multiplier: number of in-between frames to insert (>= 1).

        Returns:
            A 1-tuple holding the new batch: ``batch[:start + 1]``, the
            interpolated frames, then ``batch[end:]``.

        Raises:
            ValueError: on a non-tensor batch, fewer than 2 frames, out-of-range
                or misordered indices, ``multiplier < 1``, or RIFE output whose
                per-frame shape differs from the input's.
        """
        if batch is None or not hasattr(batch, "shape"):
            raise ValueError("Input 'batch' must be an IMAGE tensor.")
        n = int(batch.shape[0])
        if n < 2:
            raise ValueError(f"Input batch must have at least 2 frames, got {n}.")

        start = int(start)
        end = int(end)
        multiplier = int(multiplier)
        if not (0 <= start < n) or not (0 <= end < n):
            raise ValueError(f"start/end out of range. batch has {n} frames, got start={start}, end={end}.")
        if start == end:
            raise ValueError("start and end must be different indices.")
        if start > end:
            raise ValueError(f"start must be < end. Got start={start}, end={end}.")
        # The UI enforces min=1, but programmatic callers are not constrained.
        if multiplier < 1:
            raise ValueError(f"multiplier must be >= 1, got {multiplier}.")

        # RIFE only needs the two boundary frames, as a 2-frame batch.
        frames = torch.cat([batch[start:start + 1], batch[end:end + 1]], dim=0)

        # Upstream RIFE's multiplier is a factor: with 2 input frames, factor k
        # yields k - 1 in-betweens (factor=2 => 1, factor=3 => 2). So we pass
        # the requested number of in-betweens + 1.
        rife_multiplier = multiplier + 1

        rife_cls = _lazy_get_rife_class()
        (rife_out,) = rife_cls().vfi(
            ckpt_name=_HARDCODED_CKPT_NAME,
            frames=frames,
            clear_cache_after_n_frames=_HARDCODED_CLEAR_CACHE_AFTER_N_FRAMES,
            multiplier=rife_multiplier,
            fast_mode=_HARDCODED_FAST_MODE,
            ensemble=_HARDCODED_ENSEMBLE,
            scale_factor=_HARDCODED_SCALE_FACTOR,
            optional_interpolation_states=None,
        )

        # Drop RIFE's copies of the two boundary frames. [1:-1] is simply empty
        # (not an error) if RIFE returns fewer than 3 frames.
        middle = rife_out[1:-1]
        if middle.shape[0] != multiplier:
            # Best effort: proceed with whatever RIFE returned, but say so.
            print(f"SALIA_RIFE_INSERT_BETWEEN: expected {multiplier} middle frames, got {middle.shape[0]}")

        frame_shape = tuple(batch.shape[1:])
        if middle.shape[0] > 0 and tuple(middle.shape[1:]) != frame_shape:
            raise ValueError(
                f"RIFE output frame shape {tuple(middle.shape[1:])} does not match "
                f"input frame shape {frame_shape}."
            )

        # Keep the output on the input's device and in its dtype (torch.cat
        # fails across devices and would type-promote across dtypes).
        middle = middle.to(device=batch.device, dtype=batch.dtype)

        # batch[0..start], then the in-betweens, then batch[end..] (end included);
        # frames strictly between start and end are replaced.
        out = torch.cat([batch[: start + 1], middle, batch[end:]], dim=0)
        return (out,)
# Node registration tables (ComfyUI custom-node convention): internal node id
# -> implementing class, and internal node id -> human-readable display name.
NODE_CLASS_MAPPINGS = dict(
    SALIA_RIFE_INSERT_BETWEEN=SALIA_RIFE_INSERT_BETWEEN,
)
NODE_DISPLAY_NAME_MAPPINGS = dict(
    SALIA_RIFE_INSERT_BETWEEN="RIFE Insert Between (Lazy, hardcoded rife47)",
)