|
|
|
|
|
|
| import torch
|
|
|
|
|
class BatchRemoveFirstLast:
    """
    Trim both ends of an IMAGE batch.

    Given a batch of ``B`` images, the node keeps indices ``1 .. B-2``; in
    other words it drops the first image (index 0) and the last image
    (index B-1):

        Output = images[1:-1]

    Edge cases:
      - A batch with fewer than 3 images (B < 3) would be empty once both
        ends are removed, so it is passed through unchanged.
      - A single image given as [H, W, C] is promoted to a batch of one
        ([1, H, W, C]) before the size check.
      - Anything that is not a 3-D or 4-D ``torch.Tensor`` is passed through
        untouched.
    """

    # Node metadata read by the host application.
    CATEGORY = "image/batch"
    FUNCTION = "remove_first_last"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's single required input: an IMAGE named ``images``."""
        image_spec = ("IMAGE",)
        return {"required": {"images": image_spec}}

    def remove_first_last(self, images):
        """Return a 1-tuple holding ``images`` without its first and last entries.

        Args:
            images: Expected to be a ``torch.Tensor`` of rank 4 ([B, ...]) or
                rank 3 (a single image, promoted to a batch of one).

        Returns:
            ``(result,)`` where ``result`` is a fresh copy of ``batch[1:-1]``
            when the batch holds at least 3 images; otherwise the input is
            returned as-is (after the rank-3 promotion, if it applied).
        """
        # Unrecognised payload: hand it back unchanged instead of raising.
        if not isinstance(images, torch.Tensor):
            return (images,)

        rank = images.dim()
        if rank == 3:
            # Lone image -> batch of one so the batch-size check below works.
            batch = images.unsqueeze(0)
        elif rank == 4:
            batch = images
        else:
            # Unsupported rank: pass the original object straight through.
            return (images,)

        # Removing both ends of a 1- or 2-image batch would leave nothing.
        if batch.shape[0] < 3:
            return (batch,)

        # clone() so the result owns its memory rather than viewing the input.
        trimmed = batch[1:-1].clone()
        return (trimmed,)
|
|
|
|
|
# Registry mapping each exported node id to its implementing class
# (presumably discovered by the host's custom-node loader — confirm).
NODE_CLASS_MAPPINGS = dict(
    BatchRemoveFirstLast=BatchRemoveFirstLast,
)

# Human-readable title for each node id registered above.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    BatchRemoveFirstLast="Batch Remove First + Last",
)
|
|
|