# Source: walkanims / canvas_expand_crop.py (Hugging Face upload-page header)
# Uploaded by saliacoel — commit 5902a91 (verified)
# comfy_cropout_expand_nodes.py
# Put this file in: ComfyUI/custom_nodes/
# Restart ComfyUI after adding/updating.
import torch

# Reference canvas dimensions (portrait 768x1344) used as the output size of
# the Expand nodes and as the default coordinate bounds in INPUT_TYPES.
REF_W = 768
REF_H = 1344
# Alpha values at or below this threshold are treated as fully transparent.
_EPS = 1e-6
def _as_batched_image(img: torch.Tensor) -> torch.Tensor:
"""
ComfyUI IMAGE tensors are typically [B, H, W, C].
Accepts [H, W, C] as a fallback and converts to [1, H, W, C].
"""
if not isinstance(img, torch.Tensor):
raise TypeError(f"Expected torch.Tensor, got {type(img)}")
if img.dim() == 4:
return img
if img.dim() == 3:
return img.unsqueeze(0)
raise ValueError(f"Expected IMAGE tensor with 3 or 4 dims, got shape {tuple(img.shape)}")
def _clamp_top_left(x: int, y: int, size: int, width: int, height: int) -> tuple[int, int]:
"""
Clamp (x, y) so that a size x size square fits inside (width, height).
"""
x = int(x)
y = int(y)
max_x = max(0, width - size)
max_y = max(0, height - size)
if x < 0:
x = 0
elif x > max_x:
x = max_x
if y < 0:
y = 0
elif y > max_y:
y = max_y
return x, y
def _ensure_rgb(img: torch.Tensor) -> torch.Tensor:
    """
    Accept RGB or RGBA, return RGB (drop alpha if present).
    """
    batched = _as_batched_image(img)
    channels = batched.shape[-1]
    if channels == 4:
        return batched[..., :3]
    if channels == 3:
        return batched
    raise ValueError(f"Expected 3 or 4 channels, got {channels} channels")
def _ensure_rgba(img: torch.Tensor) -> torch.Tensor:
    """
    Accept RGB or RGBA, return RGBA (add opaque alpha if missing).
    """
    batched = _as_batched_image(img)
    channels = batched.shape[-1]
    if channels == 3:
        # Append a fully-opaque alpha channel matching device and dtype.
        opaque = torch.ones((*batched.shape[:-1], 1), device=batched.device, dtype=batched.dtype)
        return torch.cat([batched, opaque], dim=-1)
    if channels == 4:
        return batched
    raise ValueError(f"Expected 3 or 4 channels, got {channels} channels")
def _rect_size_check(rect: torch.Tensor, size: int) -> None:
    """Raise ValueError unless *rect* is exactly size x size pixels."""
    batched = _as_batched_image(rect)
    height, width = batched.shape[1], batched.shape[2]
    if width != size or height != size:
        raise ValueError(f"Rect input must be {size}x{size}, got {width}x{height}.")
def _white_where_alpha_zero(rgba: torch.Tensor) -> torch.Tensor:
    """
    Force the RGB channels to WHITE wherever alpha is (near) zero.

    This matches the requirement that transparent pixels read as white,
    not black, when the alpha channel is discarded downstream.
    """
    batched = _as_batched_image(rgba)
    if batched.shape[-1] != 4:
        raise ValueError("Expected RGBA tensor for _white_where_alpha_zero")
    color = batched[..., :3]
    alpha = batched[..., 3:4]
    transparent = alpha <= _EPS
    color = torch.where(transparent, torch.ones_like(color), color)
    return torch.cat([color, alpha], dim=-1)
class _CropoutBase:
    """
    Shared implementation for the fixed-size square crop nodes.

    Subclasses set SIZE to the edge length (in pixels) of the square patch
    they extract.
    """
    SIZE = None  # square edge length in pixels; override in subclasses

    @classmethod
    def INPUT_TYPES(cls):
        size = int(cls.SIZE)
        # Defaults assume the 768x1344 reference canvas. Other input sizes
        # still work because cropout() clamps the coordinates.
        return {
            "required": {
                "image": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": 0, "max": max(0, REF_W - size), "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": max(0, REF_H - size), "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "cropout"
    CATEGORY = "image/CropoutExpand"

    def cropout(self, image, x, y):
        """
        Crop a SIZE x SIZE RGB patch whose top-left corner is at (x, y).

        Coordinates are clamped so the square always lies inside the image.
        Raises ValueError when the image itself is smaller than SIZE in
        either dimension — previously the slice silently returned an
        undersized patch, which breaks downstream consumers that expect
        exactly SIZE x SIZE.
        """
        img = _ensure_rgb(image)  # RGB-only output (alpha dropped)
        b, h, w, _ = img.shape
        size = int(self.SIZE)
        if h < size or w < size:
            raise ValueError(f"Image must be at least {size}x{size}, got {w}x{h}.")
        x, y = _clamp_top_left(x, y, size, w, h)
        patch = img[:, y : y + size, x : x + size, :].contiguous()
        return (patch,)
class _ExpandBase:
    """
    Shared implementation for the expand nodes: paste a SIZE x SIZE rect
    onto a transparent, white REF_W x REF_H RGBA canvas at (x, y).
    """
    SIZE = None  # square edge length in pixels; override in subclasses

    @classmethod
    def INPUT_TYPES(cls):
        size = int(cls.SIZE)
        return {
            "required": {
                "rect": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": 0, "max": max(0, REF_W - size), "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": max(0, REF_H - size), "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "expand"
    CATEGORY = "image/CropoutExpand"

    def expand(self, rect, x, y):
        side = int(self.SIZE)
        tile = _ensure_rgba(rect)
        _rect_size_check(tile, side)
        tile = _white_where_alpha_zero(tile)
        # Output: 768x1344 RGBA canvas — white RGB, fully transparent alpha.
        batch = tile.shape[0]
        canvas = torch.zeros((batch, REF_H, REF_W, 4), device=tile.device, dtype=tile.dtype)
        canvas[..., :3] = 1.0  # white background; alpha stays 0 from zeros()
        left, top = _clamp_top_left(x, y, side, REF_W, REF_H)
        canvas[:, top : top + side, left : left + side, :] = tile
        return (canvas,)
# ---- Concrete nodes (6 total) ----
# Thin subclasses: each only pins SIZE; all behavior lives in the bases.

# Crop nodes — extract a SIZE x SIZE patch from an input image.
class Cropout_Big_384(_CropoutBase):
    SIZE = 384

class Cropout_Mid_192(_CropoutBase):
    SIZE = 192

class Cropout_Small_96(_CropoutBase):
    SIZE = 96

# Expand nodes — place a SIZE x SIZE rect onto the 768x1344 reference canvas.
class Expand_Big_384(_ExpandBase):
    SIZE = 384

class Expand_Mid_192(_ExpandBase):
    SIZE = 192

class Expand_Small_96(_ExpandBase):
    SIZE = 96
# Registration tables read by ComfyUI when it loads this custom-node module.
# Keys are the internal node identifiers.
NODE_CLASS_MAPPINGS = {
    "Cropout_Big_384": Cropout_Big_384,
    "Cropout_Mid_192": Cropout_Mid_192,
    "Cropout_Small_96": Cropout_Small_96,
    "Expand_Big_384": Expand_Big_384,
    "Expand_Mid_192": Expand_Mid_192,
    "Expand_Small_96": Expand_Small_96,
}

# Human-readable names shown in the ComfyUI node picker (same as the IDs here).
NODE_DISPLAY_NAME_MAPPINGS = {
    "Cropout_Big_384": "Cropout_Big_384",
    "Cropout_Mid_192": "Cropout_Mid_192",
    "Cropout_Small_96": "Cropout_Small_96",
    "Expand_Big_384": "Expand_Big_384",
    "Expand_Mid_192": "Expand_Mid_192",
    "Expand_Small_96": "Expand_Small_96",
}