| import torch
|
|
|
class Batch_Sprite_BBox_Cropper:
    """
    ComfyUI custom node:

    - Takes a batch of RGBA images (or RGB+MASK).
    - Alpha clamp: alpha <= (alpha_cutoff / 255) -> 0
    - Computes one global bounding box of visible pixels across the entire batch
    - Crops every image to the same bbox (spritesheet-safe)
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "alpha_cutoff": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
                "verbose": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT", "INT", "INT", "INT", "INT")
    RETURN_NAMES = ("cropped_images", "left", "top", "right", "bottom", "crop_width", "crop_height")
    FUNCTION = "process"
    CATEGORY = "image/alpha"

    @staticmethod
    def _as_rgba(images, mask, B, H, W, C):
        """Return a fresh [B,H,W,4] tensor: RGBA clone, or RGB with mask appended as alpha.

        Raises ValueError when the channel count is unsupported, the mask is
        missing for RGB input, or the mask shape does not match the images.
        """
        if C == 4:
            # Clone because the caller mutates the alpha channel in place.
            return images.clone()
        if C == 3:
            if mask is None:
                raise ValueError(
                    "Input images are RGB (C=3). Provide a MASK input or pass RGBA (C=4)."
                )
            if mask.dim() == 2:
                # Broadcast a single [H,W] mask across the whole batch.
                mask_b = mask.unsqueeze(0).expand(B, -1, -1)
            elif mask.dim() == 3:
                mask_b = mask
            else:
                raise ValueError(f"mask must be [H,W] or [B,H,W], got shape {tuple(mask.shape)}")
            if mask_b.shape[0] != B or mask_b.shape[1] != H or mask_b.shape[2] != W:
                raise ValueError(
                    f"mask shape {tuple(mask_b.shape)} must match images batch/height/width {(B,H,W)}"
                )
            # torch.cat already allocates a new tensor; no extra clone needed.
            return torch.cat([images, mask_b.unsqueeze(-1)], dim=-1)
        raise ValueError(f"Expected images with 3 (RGB) or 4 (RGBA) channels, got C={C}")

    @staticmethod
    def _global_bbox(visible):
        """Return (left, top, right, bottom) of True pixels unioned over the batch.

        visible: bool tensor [B, H, W]; caller guarantees at least one True.
        """
        union = torch.any(visible, dim=0)   # [H, W] union over the batch
        row_hits = torch.any(union, dim=1)  # rows containing any visible pixel
        col_hits = torch.any(union, dim=0)  # columns containing any visible pixel
        rows = torch.nonzero(row_hits, as_tuple=False).squeeze(1)
        cols = torch.nonzero(col_hits, as_tuple=False).squeeze(1)
        return (
            int(cols[0].item()),
            int(rows[0].item()),
            int(cols[-1].item()),
            int(rows[-1].item()),
        )

    def process(self, images, alpha_cutoff=10, verbose=True, mask=None):
        """
        Clamp near-transparent alpha to zero, then crop the whole batch to one
        shared bounding box of the remaining visible pixels.

        images: torch tensor [B, H, W, C] typically float32 in [0,1]
        mask: torch tensor [B, H, W] or [H, W] in [0,1] (optional; required when C=3)

        Returns (cropped_images, left, top, right, bottom, crop_width, crop_height).
        If no pixel survives the clamp, returns the uncropped RGBA frame.

        Raises TypeError for non-tensor input and ValueError for bad shapes.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("images must be a torch.Tensor")
        if images.dim() != 4:
            raise ValueError(f"images must be [B,H,W,C], got shape {tuple(images.shape)}")

        B, H, W, C = images.shape
        rgba = self._as_rgba(images, mask, B, H, W, C)

        # Hard-clamp near-transparent alpha to exactly zero.
        threshold = float(alpha_cutoff) / 255.0
        alpha = rgba[..., 3]
        rgba[..., 3] = torch.where(alpha <= threshold, torch.zeros_like(alpha), alpha)

        visible = rgba[..., 3] > 0
        if not torch.any(visible):
            # Nothing survived the clamp: return the full frame unchanged.
            if verbose:
                print(
                    f"[RGBABatchAlphaClampGlobalCrop] No visible pixels after clamp "
                    f"(alpha_cutoff={alpha_cutoff}). Returning unchanged RGBA."
                )
            return (rgba, 0, 0, W - 1, H - 1, W, H)

        left, top, right, bottom = self._global_bbox(visible)

        # One shared rectangle for every frame keeps spritesheet alignment intact.
        cropped = rgba[:, top:bottom + 1, left:right + 1, :]
        crop_w = right - left + 1
        crop_h = bottom - top + 1

        if verbose:
            print(
                f"[RGBABatchAlphaClampGlobalCrop] alpha_cutoff={alpha_cutoff} "
                f"-> rect: left={left}, top={top}, right={right}, bottom={bottom} "
                f"(w={crop_w}, h={crop_h}), batch={B}"
            )

        return (cropped, left, top, right, bottom, crop_w, crop_h)
|
|
|
|
|
# ComfyUI registration: internal node id -> implementing class.
NODE_CLASS_MAPPINGS = {"Batch_Sprite_BBox_Cropper": Batch_Sprite_BBox_Cropper}
|
|
|
# ComfyUI registration: internal node id -> label shown in the node picker.
NODE_DISPLAY_NAME_MAPPINGS = {"Batch_Sprite_BBox_Cropper": "Batch_Sprite_BBox_Cropper"}
|
|
|