from __future__ import annotations

import heapq
from typing import Dict, List, Tuple

import torch


class BatchEvenMotionPruner:
    """
    Remove the most redundant interior frame from an IMAGE batch until the
    requested batch size is reached.

    Redundancy score for an interior frame i:
        mean_abs_diff(frame[i], frame[left_neighbor])
        + mean_abs_diff(frame[i], frame[right_neighbor])

    The frame with the LOWEST score is removed first. After each removal the
    two surviving neighbors are re-scored against their new neighbors.
    The first and last frames are never removed, so a batch of two or more
    frames is never reduced below two frames.
    """

    CATEGORY = "image/batch"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "prune"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input specification (images + target_count)."""
        return {
            "required": {
                "images": ("IMAGE", {}),
                "target_count": (
                    "INT",
                    {
                        "default": 16,
                        "min": 1,
                        "max": 4096,
                        "step": 1,
                    },
                ),
            }
        }

    @staticmethod
    def _validate_images(images: torch.Tensor) -> torch.Tensor:
        """
        Check that *images* is a 4-D tensor, promoting a 3-D tensor to a
        batch of one.

        Raises:
            TypeError: if *images* is not a ``torch.Tensor``.
            ValueError: if *images* is neither 3-D nor 4-D.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("Expected 'images' to be a torch.Tensor.")

        # ComfyUI IMAGE is normally [B, H, W, C]. Accept [H, W, C] defensively.
        if images.ndim == 3:
            images = images.unsqueeze(0)
        elif images.ndim != 4:
            raise ValueError(
                f"Expected IMAGE tensor with shape [B,H,W,C], got shape {tuple(images.shape)}."
            )
        return images

    @staticmethod
    def _pair_key(a: int, b: int) -> Tuple[int, int]:
        """Return the two indices as an order-independent (low, high) key."""
        return (a, b) if a < b else (b, a)

    def _pair_difference(
        self,
        images: torch.Tensor,
        left_idx: int,
        right_idx: int,
        cache: Dict[Tuple[int, int], float],
    ) -> float:
        """
        Return the mean absolute difference between two frames.

        Results are memoized in *cache* under an order-independent key, so
        each distinct pair is computed at most once per ``prune`` call.
        """
        key = self._pair_key(left_idx, right_idx)
        cached = cache.get(key)
        if cached is not None:
            return cached

        # Compare in float32 regardless of the input dtype.
        left = images[left_idx].float()
        right = images[right_idx].float()

        # Mean Absolute Difference over all pixels/channels.
        value = torch.mean(torch.abs(left - right)).item()
        cache[key] = value
        return value

    def _candidate_score(
        self,
        images: torch.Tensor,
        idx: int,
        prev_idx: List[int],
        next_idx: List[int],
        cache: Dict[Tuple[int, int], float],
    ) -> float:
        """
        Return the redundancy score of frame *idx*: its difference to the
        current left neighbor plus its difference to the current right one.

        Raises:
            ValueError: if *idx* has no left or no right neighbor (link -1),
                i.e. it is an endpoint or has already been unlinked.
        """
        left = prev_idx[idx]
        right = next_idx[idx]
        if left == -1 or right == -1:
            raise ValueError("Endpoints must not be scored for removal.")
        return (
            self._pair_difference(images, left, idx, cache)
            + self._pair_difference(images, idx, right, cache)
        )

    def prune(self, images: torch.Tensor, target_count: int):
        """
        Drop the lowest-scoring interior frames until *target_count* remain.

        Args:
            images: 4-D tensor (or 3-D, treated as a batch of one).
            target_count: desired number of frames. Values below 2 are
                raised to 2 because the first and last frames are protected.

        Returns:
            A one-element tuple holding the pruned batch, frames in their
            original order. When no frame has to be removed, the validated
            input tensor itself is returned.

        Raises:
            TypeError, ValueError: propagated from ``_validate_images``.
        """
        images = self._validate_images(images)
        batch_size = int(images.shape[0])
        target_count = int(target_count)

        # With both endpoints protected, a batch of <= 2 frames has nothing
        # removable; likewise when the target is already met.
        if batch_size <= 2 or target_count >= batch_size:
            return (images,)

        # First and last are protected, so the batch cannot go below 2.
        desired_count = max(target_count, 2)

        # Doubly linked list over frame indices; -1 marks "no neighbor".
        prev_idx = [-1] + [i - 1 for i in range(1, batch_size)]
        next_idx = [i + 1 for i in range(batch_size - 1)] + [-1]
        alive = [True] * batch_size

        # Heap entries are invalidated lazily: each push bumps the frame's
        # version, and popped entries with an older version are skipped.
        candidate_version = [0] * batch_size
        pair_cache: Dict[Tuple[int, int], float] = {}
        # (score, frame index, version); equal scores resolve to the lower
        # frame index through tuple ordering.
        heap: List[Tuple[float, int, int]] = []

        def push_candidate(i: int) -> None:
            # Endpoints and already-removed frames are never candidates.
            # A live interior frame always has both links set.
            if i <= 0 or i >= batch_size - 1 or not alive[i]:
                return
            candidate_version[i] += 1
            score = self._candidate_score(images, i, prev_idx, next_idx, pair_cache)
            heapq.heappush(heap, (score, i, candidate_version[i]))

        # Seed all removable interior frames.
        for i in range(1, batch_size - 1):
            push_candidate(i)

        remaining = batch_size
        while remaining > desired_count and heap:
            _score, idx, version = heapq.heappop(heap)

            # Ignore stale heap entries.
            if not alive[idx] or candidate_version[idx] != version:
                continue

            left = prev_idx[idx]
            right = next_idx[idx]

            # Remove idx from the linked list.
            alive[idx] = False
            remaining -= 1
            next_idx[left] = right
            prev_idx[right] = left
            prev_idx[idx] = -1
            next_idx[idx] = -1

            # Only neighbors around the removed frame need updated scores.
            push_candidate(left)
            push_candidate(right)

        keep_indices = [i for i, is_alive in enumerate(alive) if is_alive]
        keep_tensor = torch.tensor(keep_indices, device=images.device, dtype=torch.long)
        output = images.index_select(0, keep_tensor)
        return (output,)


NODE_CLASS_MAPPINGS = {
    "BatchEvenMotionPruner": BatchEvenMotionPruner,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "BatchEvenMotionPruner": "Batch Even Motion Pruner",
}