# MyCustomNodes / Slide.py
import torch
class SlideCropBatch40:
    """Turn a single 3025x1024 image into a 40-frame horizontal sliding-crop batch.

    Output:
        IMAGE tensor shaped [40, 1024, 1536, C].

    The crop window slides from x = 0 to x = 1489, which is the last start
    position that keeps a 1536-wide crop inside a 3025-wide image. Because
    1489 is not divisible by 39 (the number of steps between 40 frames), a
    constant integer per-frame shift is impossible; the start positions are
    therefore the nearest integers evenly spaced over 0..1489 inclusive.
    """

    CATEGORY = "image/animation"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "make_batch"

    FRAME_COUNT = 40
    INPUT_WIDTH = 3025
    INPUT_HEIGHT = 1024
    CROP_WIDTH = 1536
    CROP_HEIGHT = 1024

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",)}}

    @classmethod
    def _positions(cls):
        """Return the 40 integer crop-start x positions, evenly spaced 0..1489."""
        steps = cls.FRAME_COUNT - 1
        span = cls.INPUT_WIDTH - cls.CROP_WIDTH  # 1489
        # Integer-only nearest rounding of frame * span / steps.
        return [(frame * span + steps // 2) // steps for frame in range(cls.FRAME_COUNT)]

    def make_batch(self, image: torch.Tensor):
        """Validate the input image and return the stacked sliding crops.

        Raises:
            TypeError: if ``image`` is not a torch.Tensor.
            ValueError: on wrong rank, batch size, resolution, or channel count.
            RuntimeError: if an individual crop comes out mis-shaped (defensive).
        """
        if not isinstance(image, torch.Tensor):
            raise TypeError("Expected IMAGE input as a torch.Tensor.")
        if image.ndim != 4:
            raise ValueError(f"Expected IMAGE tensor with shape [B,H,W,C], got shape {tuple(image.shape)}")
        batch, height, width, channels = image.shape
        if batch != 1:
            raise ValueError(
                f"This node expects exactly 1 input image (batch size 1), but got batch size {batch}."
            )
        if height != self.INPUT_HEIGHT or width != self.INPUT_WIDTH:
            raise ValueError(
                f"Expected input resolution {self.INPUT_WIDTH}x{self.INPUT_HEIGHT}, "
                f"but got {width}x{height}."
            )
        if channels < 1:
            raise ValueError(f"Expected at least 1 channel, got {channels}.")

        frame = image[0]  # drop the batch dim -> [H, W, C]
        windows = []
        for x in self._positions():
            crop = frame[:, x:x + self.CROP_WIDTH, :]
            # Defensive check; with the validation above every crop fits.
            if crop.shape[1] != self.CROP_WIDTH or crop.shape[0] != self.CROP_HEIGHT:
                raise RuntimeError(
                    f"Invalid crop at x={x}: got shape {tuple(crop.shape)}; "
                    f"expected [{self.CROP_HEIGHT}, {self.CROP_WIDTH}, C]."
                )
            windows.append(crop)
        return (torch.stack(windows, dim=0),)  # [40, 1024, 1536, C]
# Node registration tables (ComfyUI custom-node convention — TODO confirm host
# framework): internal node id -> implementing class.
NODE_CLASS_MAPPINGS = {
    "SlideCropBatch40": SlideCropBatch40,
}
# Internal node id -> human-readable label shown in the node picker UI.
NODE_DISPLAY_NAME_MAPPINGS = {
    "SlideCropBatch40": "Slide Crop Batch 40",
}