saliacoel commited on
Commit
cbb0df3
·
verified ·
1 Parent(s): 813a298

Upload BatchEvenMotionPruner.py

Browse files
Files changed (1) hide show
  1. BatchEvenMotionPruner.py +181 -0
BatchEvenMotionPruner.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import heapq
4
+ from typing import Dict, List, Tuple
5
+
6
+ import torch
7
+
8
+
9
+ class BatchEvenMotionPruner:
10
+ """
11
+ Remove the most redundant interior frame from an IMAGE batch until the
12
+ requested batch size is reached.
13
+
14
+ Redundancy score for an interior frame i:
15
+ mean_abs_diff(frame[i], frame[left_neighbor]) +
16
+ mean_abs_diff(frame[i], frame[right_neighbor])
17
+
18
+ The frame with the LOWEST score is removed first.
19
+ The first and last frames are never removed.
20
+ """
21
+
22
+ CATEGORY = "image/batch"
23
+ RETURN_TYPES = ("IMAGE",)
24
+ RETURN_NAMES = ("images",)
25
+ FUNCTION = "prune"
26
+
27
+ @classmethod
28
+ def INPUT_TYPES(cls):
29
+ return {
30
+ "required": {
31
+ "images": ("IMAGE", {}),
32
+ "target_count": (
33
+ "INT",
34
+ {
35
+ "default": 16,
36
+ "min": 1,
37
+ "max": 4096,
38
+ "step": 1,
39
+ },
40
+ ),
41
+ }
42
+ }
43
+
44
+ @staticmethod
45
+ def _validate_images(images: torch.Tensor) -> torch.Tensor:
46
+ if not isinstance(images, torch.Tensor):
47
+ raise TypeError("Expected 'images' to be a torch.Tensor.")
48
+
49
+ # ComfyUI IMAGE is normally [B, H, W, C]. Accept [H, W, C] defensively.
50
+ if images.ndim == 3:
51
+ images = images.unsqueeze(0)
52
+ elif images.ndim != 4:
53
+ raise ValueError(
54
+ f"Expected IMAGE tensor with shape [B,H,W,C], got shape {tuple(images.shape)}."
55
+ )
56
+
57
+ return images
58
+
59
+ @staticmethod
60
+ def _pair_key(a: int, b: int) -> Tuple[int, int]:
61
+ return (a, b) if a < b else (b, a)
62
+
63
+ def _pair_difference(
64
+ self,
65
+ images: torch.Tensor,
66
+ left_idx: int,
67
+ right_idx: int,
68
+ cache: Dict[Tuple[int, int], float],
69
+ ) -> float:
70
+ key = self._pair_key(left_idx, right_idx)
71
+ cached = cache.get(key)
72
+ if cached is not None:
73
+ return cached
74
+
75
+ left = images[left_idx].float()
76
+ right = images[right_idx].float()
77
+
78
+ # Mean Absolute Difference over all pixels/channels.
79
+ value = torch.mean(torch.abs(left - right)).item()
80
+ cache[key] = value
81
+ return value
82
+
83
+ def _candidate_score(
84
+ self,
85
+ images: torch.Tensor,
86
+ idx: int,
87
+ prev_idx: List[int],
88
+ next_idx: List[int],
89
+ cache: Dict[Tuple[int, int], float],
90
+ ) -> float:
91
+ left = prev_idx[idx]
92
+ right = next_idx[idx]
93
+ if left == -1 or right == -1:
94
+ raise ValueError("Endpoints must not be scored for removal.")
95
+
96
+ return (
97
+ self._pair_difference(images, left, idx, cache)
98
+ + self._pair_difference(images, idx, right, cache)
99
+ )
100
+
101
+ def prune(self, images: torch.Tensor, target_count: int):
102
+ images = self._validate_images(images)
103
+
104
+ batch_size = int(images.shape[0])
105
+ target_count = int(target_count)
106
+
107
+ if batch_size <= 1 or target_count >= batch_size:
108
+ return (images,)
109
+
110
+ # If first and last are protected, batches with 2+ frames cannot go below 2.
111
+ minimum_reachable = 1 if batch_size <= 1 else 2
112
+ desired_count = max(target_count, minimum_reachable)
113
+
114
+ if desired_count >= batch_size:
115
+ return (images,)
116
+
117
+ prev_idx = [-1] + [i - 1 for i in range(1, batch_size)]
118
+ next_idx = [i + 1 for i in range(batch_size - 1)] + [-1]
119
+ alive = [True] * batch_size
120
+ candidate_version = [0] * batch_size
121
+ pair_cache: Dict[Tuple[int, int], float] = {}
122
+ heap: List[Tuple[float, int, int]] = []
123
+
124
+ def push_candidate(i: int) -> None:
125
+ if i <= 0 or i >= batch_size - 1:
126
+ return
127
+ if not alive[i]:
128
+ return
129
+ if prev_idx[i] == -1 or next_idx[i] == -1:
130
+ return
131
+
132
+ candidate_version[i] += 1
133
+ score = self._candidate_score(images, i, prev_idx, next_idx, pair_cache)
134
+ heapq.heappush(heap, (score, i, candidate_version[i]))
135
+
136
+ # Seed all removable interior frames.
137
+ for i in range(1, batch_size - 1):
138
+ push_candidate(i)
139
+
140
+ remaining = batch_size
141
+
142
+ while remaining > desired_count and heap:
143
+ _score, idx, version = heapq.heappop(heap)
144
+
145
+ # Ignore stale heap entries.
146
+ if not alive[idx]:
147
+ continue
148
+ if candidate_version[idx] != version:
149
+ continue
150
+ if prev_idx[idx] == -1 or next_idx[idx] == -1:
151
+ continue
152
+
153
+ left = prev_idx[idx]
154
+ right = next_idx[idx]
155
+
156
+ # Remove idx from the linked list.
157
+ alive[idx] = False
158
+ remaining -= 1
159
+
160
+ next_idx[left] = right
161
+ prev_idx[right] = left
162
+ prev_idx[idx] = -1
163
+ next_idx[idx] = -1
164
+
165
+ # Only neighbors around the removed frame need updated scores.
166
+ push_candidate(left)
167
+ push_candidate(right)
168
+
169
+ keep_indices = [i for i, is_alive in enumerate(alive) if is_alive]
170
+ keep_tensor = torch.tensor(keep_indices, device=images.device, dtype=torch.long)
171
+ output = images.index_select(0, keep_tensor)
172
+ return (output,)
173
+
174
+
175
+ NODE_CLASS_MAPPINGS = {
176
+ "BatchEvenMotionPruner": BatchEvenMotionPruner,
177
+ }
178
+
179
+ NODE_DISPLAY_NAME_MAPPINGS = {
180
+ "BatchEvenMotionPruner": "Batch Even Motion Pruner",
181
+ }