|
|
|
|
|
|
|
|
| import torch
|
|
|
# Width of the reference canvas (pixels) that Expand nodes paste onto and
# that bounds the x/y widget ranges of both node families.
REF_W = 768

# Height of the reference canvas (pixels).
REF_H = 1344



# Alpha values at or below this threshold are treated as fully transparent.
_EPS = 1e-6
|
|
|
|
|
| def _as_batched_image(img: torch.Tensor) -> torch.Tensor:
|
| """
|
| ComfyUI IMAGE tensors are typically [B, H, W, C].
|
| Accepts [H, W, C] as a fallback and converts to [1, H, W, C].
|
| """
|
| if not isinstance(img, torch.Tensor):
|
| raise TypeError(f"Expected torch.Tensor, got {type(img)}")
|
|
|
| if img.dim() == 4:
|
| return img
|
| if img.dim() == 3:
|
| return img.unsqueeze(0)
|
|
|
| raise ValueError(f"Expected IMAGE tensor with 3 or 4 dims, got shape {tuple(img.shape)}")
|
|
|
|
|
| def _clamp_top_left(x: int, y: int, size: int, width: int, height: int) -> tuple[int, int]:
|
| """
|
| Clamp (x, y) so that a size x size square fits inside (width, height).
|
| """
|
| x = int(x)
|
| y = int(y)
|
| max_x = max(0, width - size)
|
| max_y = max(0, height - size)
|
| if x < 0:
|
| x = 0
|
| elif x > max_x:
|
| x = max_x
|
| if y < 0:
|
| y = 0
|
| elif y > max_y:
|
| y = max_y
|
| return x, y
|
|
|
|
|
def _ensure_rgb(img: torch.Tensor) -> torch.Tensor:
    """
    Normalize to a batched RGB image, discarding any alpha channel.

    Accepts 3- or 4-channel input; anything else raises ValueError.
    """
    batched = _as_batched_image(img)
    channels = batched.shape[-1]
    if channels not in (3, 4):
        raise ValueError(f"Expected 3 or 4 channels, got {channels} channels")
    # Slice off alpha when present; RGB input is returned as-is.
    return batched[..., :3] if channels == 4 else batched
|
|
|
|
|
def _ensure_rgba(img: torch.Tensor) -> torch.Tensor:
    """
    Normalize to a batched RGBA image, appending an opaque alpha if absent.

    Accepts 3- or 4-channel input; anything else raises ValueError.
    """
    batched = _as_batched_image(img)
    channels = batched.shape[-1]
    if channels == 4:
        return batched
    if channels != 3:
        raise ValueError(f"Expected 3 or 4 channels, got {channels} channels")
    # RGB input: synthesize a fully opaque alpha plane on the same device/dtype.
    opaque = torch.ones((*batched.shape[:-1], 1), device=batched.device, dtype=batched.dtype)
    return torch.cat([batched, opaque], dim=-1)
|
|
|
|
|
def _rect_size_check(rect: torch.Tensor, size: int) -> None:
    """
    Raise ValueError unless the rect's spatial dims are exactly size x size.
    """
    batched = _as_batched_image(rect)
    h, w = batched.shape[1], batched.shape[2]
    if (h, w) != (size, size):
        raise ValueError(f"Rect input must be {size}x{size}, got {w}x{h}.")
|
|
|
|
|
def _white_where_alpha_zero(rgba: torch.Tensor) -> torch.Tensor:
    """
    Force RGB to white wherever alpha is (near) zero.

    Keeps transparent regions white rather than black, so dropping or
    compositing the alpha channel later cannot reveal black pixels.
    """
    rgba = _as_batched_image(rgba)
    if rgba.shape[-1] != 4:
        raise ValueError("Expected RGBA tensor for _white_where_alpha_zero")

    color = rgba[..., :3]
    alpha = rgba[..., 3:4]
    # alpha <= _EPS counts as fully transparent.
    color = torch.where(alpha <= _EPS, torch.ones_like(color), color)
    return torch.cat([color, alpha], dim=-1)
|
|
|
|
|
class _CropoutBase:
    """
    Shared implementation for fixed-size square crop nodes.

    Subclasses set SIZE (side length in pixels). The node extracts a
    SIZE x SIZE patch from the input IMAGE at (x, y), clamped so the
    patch stays inside the image.
    """

    # Side length of the square crop; concrete subclasses override this.
    SIZE = None

    @classmethod
    def INPUT_TYPES(cls):
        size = int(cls.SIZE)

        # Widget ranges are based on the reference canvas; cropout() re-clamps
        # against the actual image dimensions at execution time.
        return {
            "required": {
                "image": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": 0, "max": max(0, REF_W - size), "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": max(0, REF_H - size), "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "cropout"
    CATEGORY = "image/CropoutExpand"

    def cropout(self, image, x, y):
        """
        Crop a SIZE x SIZE RGB patch at (x, y), clamped into the image bounds.

        Raises:
            ValueError: if the image is smaller than SIZE in either dimension.
        """
        img = _ensure_rgb(image)
        _, h, w, _ = img.shape
        size = int(self.SIZE)

        # Fix: an image smaller than SIZE previously produced a silently
        # undersized patch (slicing past the edge just truncates), which
        # breaks the SIZE x SIZE contract the Expand nodes enforce.
        if h < size or w < size:
            raise ValueError(f"Image must be at least {size}x{size}, got {w}x{h}.")

        x, y = _clamp_top_left(x, y, size, w, h)
        patch = img[:, y : y + size, x : x + size, :].contiguous()
        return (patch,)
|
|
|
|
|
class _ExpandBase:
    """
    Shared implementation for nodes that paste a fixed-size square patch
    onto a white, fully transparent REF_W x REF_H RGBA canvas.

    Subclasses set SIZE (side length in pixels of the expected rect).
    """

    # Side length of the square rect; concrete subclasses override this.
    SIZE = None

    @classmethod
    def INPUT_TYPES(cls):
        size = int(cls.SIZE)

        def _coord(limit):
            # x/y widget spec bounded so the rect fits on the reference canvas.
            return ("INT", {"default": 0, "min": 0, "max": max(0, limit - size), "step": 1})

        return {
            "required": {
                "rect": ("IMAGE",),
                "x": _coord(REF_W),
                "y": _coord(REF_H),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "expand"
    CATEGORY = "image/CropoutExpand"

    def expand(self, rect, x, y):
        """
        Place the SIZE x SIZE rect onto a REF_W x REF_H RGBA canvas at (x, y).

        The canvas is white RGB with zero alpha outside the patch; inside the
        patch, near-zero-alpha pixels are forced to white as well.
        """
        size = int(self.SIZE)

        rgba = _ensure_rgba(rect)
        _rect_size_check(rgba, size)
        rgba = _white_where_alpha_zero(rgba)

        batch = rgba.shape[0]
        # torch.zeros gives alpha = 0 everywhere; paint RGB white on top.
        canvas = torch.zeros((batch, REF_H, REF_W, 4), device=rgba.device, dtype=rgba.dtype)
        canvas[..., :3] = 1.0

        px, py = _clamp_top_left(x, y, size, REF_W, REF_H)
        canvas[:, py : py + size, px : px + size, :] = rgba
        return (canvas,)
|
|
|
|
|
|
|
|
|
# Crops a 384x384 patch (see _CropoutBase).
class Cropout_Big_384(_CropoutBase):

    SIZE = 384
|
|
|
|
|
# Crops a 192x192 patch (see _CropoutBase).
class Cropout_Mid_192(_CropoutBase):

    SIZE = 192
|
|
|
|
|
# Crops a 96x96 patch (see _CropoutBase).
class Cropout_Small_96(_CropoutBase):

    SIZE = 96
|
|
|
|
|
# Pastes a 384x384 rect onto the reference canvas (see _ExpandBase).
class Expand_Big_384(_ExpandBase):

    SIZE = 384
|
|
|
|
|
# Pastes a 192x192 rect onto the reference canvas (see _ExpandBase).
class Expand_Mid_192(_ExpandBase):

    SIZE = 192
|
|
|
|
|
# Pastes a 96x96 rect onto the reference canvas (see _ExpandBase).
class Expand_Small_96(_ExpandBase):

    SIZE = 96
|
|
|
|
|
# Registry consumed by ComfyUI: internal node id -> implementing class.
NODE_CLASS_MAPPINGS = {

    "Cropout_Big_384": Cropout_Big_384,

    "Cropout_Mid_192": Cropout_Mid_192,

    "Cropout_Small_96": Cropout_Small_96,

    "Expand_Big_384": Expand_Big_384,

    "Expand_Mid_192": Expand_Mid_192,

    "Expand_Small_96": Expand_Small_96,

}



# Display titles shown in the ComfyUI node picker (identical to the ids here).
NODE_DISPLAY_NAME_MAPPINGS = {

    "Cropout_Big_384": "Cropout_Big_384",

    "Cropout_Mid_192": "Cropout_Mid_192",

    "Cropout_Small_96": "Cropout_Small_96",

    "Expand_Big_384": "Expand_Big_384",

    "Expand_Mid_192": "Expand_Mid_192",

    "Expand_Small_96": "Expand_Small_96",

}
|
|
|