|
|
| import torch |
|
|
class TemporalHintFromPair:
    """
    Concatenate two RGB images (current & previous) along the channel dim into a 6-channel IMAGE.

    Images are channels-last tensors, either batched ``[B, H, W, 3]`` or unbatched
    ``[H, W, 3]`` (a batch dim of 1 is added). The result is ``[B, H, W, 6]``:
    channels 0-2 hold ``current`` and channels 3-5 hold ``previous``.

    If ``previous`` is None, ``current`` is used in its place (no-op for the first frame).
    """

    @classmethod
    def INPUT_TYPES(cls):
        # NOTE: "previous" is declared required, so the None fallback in
        # make_hint is only reachable by direct (programmatic) callers.
        return {
            "required": {
                "current": ("IMAGE",),
                "previous": ("IMAGE",),
            },
            "optional": {
                "clip_to_range": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("temporal_hint",)
    FUNCTION = "make_hint"
    CATEGORY = "Temporal/Utils"

    @staticmethod
    def _ensure_batch(x):
        """Add a leading batch dim to a 3-D ``[H, W, C]`` tensor; return other tensors unchanged."""
        if x.dim() == 3:
            x = x.unsqueeze(0)
        return x

    @staticmethod
    def _match_batch(a, b):
        """Reconcile the batch sizes of two 4-D tensors.

        - Equal sizes: both are returned unchanged.
        - One side has batch 1: it is repeated to the other side's batch size.
        - Otherwise: both are truncated to the smaller batch size. Trailing
          frames of the longer input are silently dropped.
        """
        ba, bb = a.shape[0], b.shape[0]
        if ba == bb:
            return a, b
        if ba == 1:
            a = a.repeat(bb, 1, 1, 1)
        elif bb == 1:
            b = b.repeat(ba, 1, 1, 1)
        else:
            n = min(ba, bb)
            a = a[:n]
            b = b[:n]
        return a, b

    def make_hint(self, current, previous=None, clip_to_range=True):
        """Build the 6-channel temporal hint from a current/previous image pair.

        Args:
            current: channels-last RGB image(s), ``[B, H, W, 3]`` or ``[H, W, 3]``.
            previous: same layout as ``current``. If None, ``current`` is used
                instead, so the hint repeats the current frame.
            clip_to_range: if True, clamp both images to ``[0, 1]`` before concatenation.

        Returns:
            A 1-tuple holding a ``[B, H, W, 6]`` tensor: ``current`` in channels
            0-2 and ``previous`` in channels 3-5. ``previous`` is resized
            (nearest-neighbour) to ``current``'s height/width when they differ.
            Batch sizes are reconciled as described in ``_match_batch``.

        Raises:
            ValueError: if either input's last dimension is not 3.
        """
        # Documented first-frame fallback: with no previous frame, pair current with itself.
        if previous is None:
            previous = current

        current = self._ensure_batch(current)
        previous = self._ensure_batch(previous)

        if current.shape[-1] != 3 or previous.shape[-1] != 3:
            raise ValueError(f"Expected RGB images with 3 channels; got {current.shape} & {previous.shape}")

        # interpolate/cat require both tensors on one device; follow current's.
        previous = previous.to(current.device)

        current, previous = self._match_batch(current, previous)

        if current.shape[1:3] != previous.shape[1:3]:
            # interpolate expects channels-first [B, C, H, W]; permute in and back out.
            previous = torch.nn.functional.interpolate(
                previous.permute(0, 3, 1, 2),
                size=(current.shape[1], current.shape[2]),
                mode="nearest",
            ).permute(0, 2, 3, 1)

        if clip_to_range:
            current = current.clamp(0.0, 1.0)
            previous = previous.clamp(0.0, 1.0)

        temporal_hint = torch.cat([current, previous], dim=3)
        return (temporal_hint,)
|
|
|
# ComfyUI node registration: maps each node's internal identifier to its class ...
NODE_CLASS_MAPPINGS = dict(
    TemporalHintFromPair=TemporalHintFromPair,
)

# ... and to the human-readable label shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    TemporalHintFromPair="Temporal Hint From Pair (6ch)",
)
|
|