import torch
|
|
|
|
|
class SlideCropBatch40:
    """Create a 40-frame horizontal sliding-crop batch from one 3025x1024 image.

    Output:
        IMAGE tensor with shape [40, 1024, 1536, C].

    Notes:
        - The first crop covers x = 0..1535 inclusive.
        - A 1536-wide crop taken from a 3025-wide image can start no later
          than x = 1489 (ending at x = 3024 inclusive).
        - 1489 / 39 is not an integer, so the per-frame shift over 40 frames
          cannot be both constant and integer.
        - The node therefore uses the nearest integer positions that are
          evenly spaced from 0 to 1489 inclusive.
    """

    # Node metadata. NOTE(review): the attribute names look like the ComfyUI
    # custom-node convention -- confirm against the host application.
    CATEGORY = "image/animation"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "make_batch"  # name of the method below that does the work

    # Fixed geometry of the node: number of output frames, the only accepted
    # input resolution, and the size of every crop.
    FRAME_COUNT = 40
    INPUT_WIDTH = 3025
    INPUT_HEIGHT = 1024
    CROP_WIDTH = 1536
    CROP_HEIGHT = 1024

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input declaration: one required ``image`` of type IMAGE."""
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    @classmethod
    def _positions(cls):
        """Return the left-edge x offset of every crop, in frame order.

        The offsets run from 0 to ``INPUT_WIDTH - CROP_WIDTH`` and are spaced
        as evenly as integers allow: frame ``i`` gets
        ``i * max_shift / (FRAME_COUNT - 1)`` rounded to the nearest integer,
        computed with integer arithmetic only.

        Returns:
            A list of ``FRAME_COUNT`` ints.
        """
        intervals = cls.FRAME_COUNT - 1
        max_shift = cls.INPUT_WIDTH - cls.CROP_WIDTH

        if intervals == 0:
            # A single frame leaves no interval to divide by; anchor it at the
            # left edge instead of raising ZeroDivisionError.
            return [0]

        # Adding half the divisor before floor-dividing rounds to nearest.
        half = intervals // 2
        return [(i * max_shift + half) // intervals for i in range(cls.FRAME_COUNT)]

    def make_batch(self, image: torch.Tensor):
        """Slice the input image into a batch of horizontally shifted crops.

        Args:
            image: Tensor of shape [1, INPUT_HEIGHT, INPUT_WIDTH, C], C >= 1.

        Returns:
            A 1-tuple holding a tensor of shape
            [FRAME_COUNT, CROP_HEIGHT, CROP_WIDTH, C]; frame ``k`` starts at
            column ``_positions()[k]`` of the input.

        Raises:
            TypeError: ``image`` is not a ``torch.Tensor``.
            ValueError: ``image`` is not 4-D, its batch size is not 1, its
                resolution differs from INPUT_WIDTH x INPUT_HEIGHT, or it has
                no channels.
            RuntimeError: a crop came out with an unexpected height or width.
        """
        if not isinstance(image, torch.Tensor):
            raise TypeError("Expected IMAGE input as a torch.Tensor.")

        if image.ndim != 4:
            raise ValueError(f"Expected IMAGE tensor with shape [B,H,W,C], got shape {tuple(image.shape)}")

        batch, height, width, channels = image.shape

        if batch != 1:
            raise ValueError(
                f"This node expects exactly 1 input image (batch size 1), but got batch size {batch}."
            )

        if height != self.INPUT_HEIGHT or width != self.INPUT_WIDTH:
            raise ValueError(
                f"Expected input resolution {self.INPUT_WIDTH}x{self.INPUT_HEIGHT}, "
                f"but got {width}x{height}."
            )

        if channels < 1:
            raise ValueError(f"Expected at least 1 channel, got {channels}.")

        single = image[0]  # drop the batch axis -> [H, W, C]
        crop_w = self.CROP_WIDTH
        crop_h = self.CROP_HEIGHT
        crops = []

        for x in self._positions():
            crop = single[:, x:x + crop_w, :]
            # Slicing past the right edge would silently yield a narrower
            # crop, so verify every crop before stacking.
            if crop.shape[1] != crop_w or crop.shape[0] != crop_h:
                raise RuntimeError(
                    f"Invalid crop at x={x}: got shape {tuple(crop.shape)}; "
                    f"expected [{self.CROP_HEIGHT}, {self.CROP_WIDTH}, C]."
                )
            crops.append(crop)

        # Stack along a new leading axis so the frames form the batch dimension.
        output = torch.stack(crops, dim=0)
        return (output,)
|
|
|
|
|
# Maps the node identifier string to the class that implements it.
# NOTE(review): presumably read by the host application's node loader -- confirm.
NODE_CLASS_MAPPINGS = {
    "SlideCropBatch40": SlideCropBatch40,
}
|
|
|
# Maps the node identifier string to a human-readable label.
# NOTE(review): presumably shown in the host application's UI -- confirm.
NODE_DISPLAY_NAME_MAPPINGS = {
    "SlideCropBatch40": "Slide Crop Batch 40",
}