| import torch
|
| from typing import List
|
|
|
|
|
class BatchFilterKeepFirstLast:
    """
    Batch filter node (IMAGE -> IMAGE) that always keeps the first and last image.

    Modes (int):
    - 0  : passthrough (no changes)
    - 10 : keep 1st, 3rd, 5th, ... (drop every 2nd), but always keep last
    - <10: keep slightly MORE than mode 10 (adds back frames, evenly distributed)
    - >10: keep slightly FEWER than mode 10 (removes extra frames, evenly distributed)

    Each mode step away from 10 adds/removes
    ``max(1, round(B * ADJUST_PER_STEP_FRACTION))`` frames. The amount is capped
    so that the result never exceeds the full batch and never drops the first
    or last frame.

    Notes:
    - In ComfyUI, IMAGE is a batch (torch.Tensor) of shape [B, H, W, C]. We only filter B.
    - Works with RGBA (C=4) or RGB (C=3) since we do not modify channels.
    """

    CATEGORY = "image/batch"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "filter_batch"

    # Fraction of the batch size added/removed per mode step away from 10
    # (always at least one frame per step).
    ADJUST_PER_STEP_FRACTION = 0.05

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: an IMAGE batch and an integer mode in [0, 20]."""
        return {
            "required": {
                "images": ("IMAGE",),
                "mode": ("INT", {"default": 10, "min": 0, "max": 20, "step": 1}),
            }
        }

    def filter_batch(self, images: torch.Tensor, mode: int):
        """
        Filter the batch dimension of ``images`` according to ``mode``.

        Args:
            images: 4-D tensor laid out as [B, H, W, C]; only dim 0 is filtered.
            mode: 0 = passthrough, 10 = keep every other frame (plus the last),
                lower values keep more frames, higher values keep fewer.

        Returns:
            A 1-tuple holding the filtered tensor. For mode 0 or a batch of at
            most one image, the input tensor object itself is returned.

        Raises:
            TypeError: if ``images`` is not a torch.Tensor.
            ValueError: if ``images`` is not 4-dimensional.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("images must be a torch.Tensor")

        if images.ndim != 4:
            raise ValueError(f"Expected images with shape [B,H,W,C], got {tuple(images.shape)}")

        b = int(images.shape[0])
        if b <= 1 or mode == 0:
            return (images,)

        keep = self._compute_keep_indices(b, mode)
        return (images[keep, ...],)

    def _compute_keep_indices(self, b: int, mode: int) -> List[int]:
        """Return the sorted batch indices to keep for a batch of size ``b >= 2``."""
        last = b - 1

        # Baseline (mode 10): every other frame starting at 0, plus the last one.
        keep = set(range(0, b, 2))
        keep.add(last)

        delta = mode - 10
        if delta == 0:
            return sorted(keep)

        step = max(1, round(b * self.ADJUST_PER_STEP_FRACTION))

        if delta < 0:
            # Add back dropped frames. Both endpoints are already kept, so every
            # index not in `keep` is an interior candidate.
            candidates = [i for i in range(b) if i not in keep]
            add_count = min(-delta * step, len(candidates))
            keep.update(self._evenly_pick(candidates, add_count))
        else:
            # Remove interior frames only; first and last are never candidates,
            # which guarantees at least two frames survive.
            removable = [i for i in sorted(keep) if i != 0 and i != last]
            remove_count = min(delta * step, len(removable))
            keep.difference_update(self._evenly_pick(removable, remove_count))

        return sorted(keep)

    @staticmethod
    def _evenly_pick(items: List[int], k: int) -> List[int]:
        """
        Pick k unique elements from items, evenly distributed across the list.

        Deterministic and order-preserving. ``k`` is clamped to ``len(items)``;
        a non-positive ``k`` or an empty ``items`` yields an empty list.
        """
        m = len(items)
        if k <= 0 or m == 0:
            return []
        k = min(k, m)

        # Split the list into k+1 equal gaps and take the element that ends each
        # of the first k gaps: pos_i = floor((i+1) * (m+1) / (k+1)) - 1.
        # Integer floor division keeps this exact. Since (m+1)/(k+1) >= 1, the
        # positions are strictly increasing and lie within [0, m-1].
        return [items[(i + 1) * (m + 1) // (k + 1) - 1] for i in range(k)]
|
|
|
|
|
# Node type identifier under which this node is registered with ComfyUI.
_NODE_ID = "BatchFilterKeepFirstLast"

# Maps the node type identifier to its implementing class.
NODE_CLASS_MAPPINGS = {_NODE_ID: BatchFilterKeepFirstLast}

# Maps the node type identifier to the title shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = {_NODE_ID: "Batch Filter (Keep First/Last)"}
|
|
|