saliacoel commited on
Commit
5902a91
·
verified ·
1 Parent(s): 15e94d2

Upload canvas_expand_crop.py

Browse files
Files changed (1) hide show
  1. canvas_expand_crop.py +209 -0
canvas_expand_crop.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # comfy_cropout_expand_nodes.py
2
+ # Put this file in: ComfyUI/custom_nodes/
3
+ # Restart ComfyUI after adding/updating.
4
+
5
+ import torch
6
+
7
# Reference canvas size used by the Expand nodes and for the default
# widget ranges of the x/y coordinates (width x height).
REF_W = 768
REF_H = 1344

# Alpha values at or below this threshold are treated as fully transparent
# (used by _white_where_alpha_zero).
_EPS = 1e-6
11
+
12
+
13
+ def _as_batched_image(img: torch.Tensor) -> torch.Tensor:
14
+ """
15
+ ComfyUI IMAGE tensors are typically [B, H, W, C].
16
+ Accepts [H, W, C] as a fallback and converts to [1, H, W, C].
17
+ """
18
+ if not isinstance(img, torch.Tensor):
19
+ raise TypeError(f"Expected torch.Tensor, got {type(img)}")
20
+
21
+ if img.dim() == 4:
22
+ return img
23
+ if img.dim() == 3:
24
+ return img.unsqueeze(0)
25
+
26
+ raise ValueError(f"Expected IMAGE tensor with 3 or 4 dims, got shape {tuple(img.shape)}")
27
+
28
+
29
+ def _clamp_top_left(x: int, y: int, size: int, width: int, height: int) -> tuple[int, int]:
30
+ """
31
+ Clamp (x, y) so that a size x size square fits inside (width, height).
32
+ """
33
+ x = int(x)
34
+ y = int(y)
35
+ max_x = max(0, width - size)
36
+ max_y = max(0, height - size)
37
+ if x < 0:
38
+ x = 0
39
+ elif x > max_x:
40
+ x = max_x
41
+ if y < 0:
42
+ y = 0
43
+ elif y > max_y:
44
+ y = max_y
45
+ return x, y
46
+
47
+
48
def _ensure_rgb(img: torch.Tensor) -> torch.Tensor:
    """
    Accept an RGB or RGBA image, return an RGB one (alpha dropped if present).
    """
    batched = _as_batched_image(img)
    channels = batched.shape[-1]
    if channels not in (3, 4):
        raise ValueError(f"Expected 3 or 4 channels, got {channels} channels")
    return batched if channels == 3 else batched[..., :3]
59
+
60
+
61
def _ensure_rgba(img: torch.Tensor) -> torch.Tensor:
    """
    Accept an RGB or RGBA image, return an RGBA one (opaque alpha added if missing).
    """
    batched = _as_batched_image(img)
    channels = batched.shape[-1]
    if channels == 4:
        return batched
    if channels != 3:
        raise ValueError(f"Expected 3 or 4 channels, got {channels} channels")
    # Append a fully opaque alpha channel matching the image's device/dtype.
    opaque = torch.ones((*batched.shape[:-1], 1), device=batched.device, dtype=batched.dtype)
    return torch.cat([batched, opaque], dim=-1)
73
+
74
+
75
def _rect_size_check(rect: torch.Tensor, size: int) -> None:
    """Raise ValueError unless *rect* is exactly size x size pixels."""
    batched = _as_batched_image(rect)
    h, w = batched.shape[1], batched.shape[2]
    if (h, w) != (size, size):
        raise ValueError(f"Rect input must be {size}x{size}, got {w}x{h}.")
81
+
82
+
83
def _white_where_alpha_zero(rgba: torch.Tensor) -> torch.Tensor:
    """
    Force RGB to WHITE wherever alpha is (near) zero.

    This matches the requirement that fully transparent pixels read as
    white rather than black when the alpha channel is discarded.
    """
    rgba = _as_batched_image(rgba)
    if rgba.shape[-1] != 4:
        raise ValueError("Expected RGBA tensor for _white_where_alpha_zero")

    color = rgba[..., :3]
    alpha = rgba[..., 3:4]
    # Broadcast the alpha<=eps mask over the three color channels.
    whitened = torch.where(alpha <= _EPS, torch.ones_like(color), color)
    return torch.cat([whitened, alpha], dim=-1)
97
+
98
+
99
class _CropoutBase:
    """
    Shared implementation for the Cropout_* nodes: cut a SIZE x SIZE RGB
    patch out of an input IMAGE at a clamped (x, y) top-left position.

    Subclasses only set SIZE.
    """

    SIZE = None  # override: pixel side length of the square crop

    @classmethod
    def INPUT_TYPES(cls):
        size = int(cls.SIZE)
        # Defaults assume the 768x1344 reference. Still works if input differs; coords are clamped.
        return {
            "required": {
                "image": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": 0, "max": max(0, REF_W - size), "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": max(0, REF_H - size), "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "cropout"
    CATEGORY = "image/CropoutExpand"

    def cropout(self, image, x, y):
        """
        Crop a SIZE x SIZE RGB patch at (x, y).

        Coordinates are clamped so the square stays inside the image.

        Raises:
            ValueError: if the image is smaller than SIZE in either
                dimension. Previously a clamped slice silently returned a
                patch smaller than SIZE x SIZE, which the Expand nodes
                would later reject far from the actual cause.
        """
        img = _ensure_rgb(image)  # RGB only output
        b, h, w, _ = img.shape
        size = int(self.SIZE)

        if h < size or w < size:
            raise ValueError(
                f"Input image ({w}x{h}) is smaller than the {size}x{size} crop size."
            )

        x, y = _clamp_top_left(x, y, size, w, h)
        patch = img[:, y : y + size, x : x + size, :].contiguous()
        return (patch,)
127
+
128
+
129
class _ExpandBase:
    """
    Shared implementation for the Expand_* nodes: paste a SIZE x SIZE RGBA
    rect at (x, y) onto a REF_W x REF_H canvas that is white and fully
    transparent everywhere else.

    Subclasses only set SIZE.
    """

    SIZE = None  # override: pixel side length of the square rect

    @classmethod
    def INPUT_TYPES(cls):
        size = int(cls.SIZE)
        return {
            "required": {
                "rect": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": 0, "max": max(0, REF_W - size), "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": max(0, REF_H - size), "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "expand"
    CATEGORY = "image/CropoutExpand"

    def expand(self, rect, x, y):
        """Place *rect* on a transparent white reference-size canvas."""
        size = int(self.SIZE)

        rect_rgba = _ensure_rgba(rect)
        _rect_size_check(rect_rgba, size)

        # Transparent pixels must carry white RGB, not black.
        rect_rgba = _white_where_alpha_zero(rect_rgba)

        # Canvas: all-ones RGBA (white), then alpha zeroed -> transparent white.
        batch = rect_rgba.shape[0]
        canvas = torch.ones(
            (batch, REF_H, REF_W, 4), device=rect_rgba.device, dtype=rect_rgba.dtype
        )
        canvas[..., 3] = 0.0  # fully transparent

        px, py = _clamp_top_left(x, y, size, REF_W, REF_H)
        canvas[:, py : py + size, px : px + size, :] = rect_rgba
        return (canvas,)
165
+
166
+
167
+ # ---- Concrete nodes (6 total) ----
168
+
169
class Cropout_Big_384(_CropoutBase):
    """384x384 crop node; all behavior lives in _CropoutBase."""
    SIZE = 384
171
+
172
+
173
class Cropout_Mid_192(_CropoutBase):
    """192x192 crop node; all behavior lives in _CropoutBase."""
    SIZE = 192
175
+
176
+
177
class Cropout_Small_96(_CropoutBase):
    """96x96 crop node; all behavior lives in _CropoutBase."""
    SIZE = 96
179
+
180
+
181
class Expand_Big_384(_ExpandBase):
    """Pastes a 384x384 rect onto the reference canvas; see _ExpandBase."""
    SIZE = 384
183
+
184
+
185
class Expand_Mid_192(_ExpandBase):
    """Pastes a 192x192 rect onto the reference canvas; see _ExpandBase."""
    SIZE = 192
187
+
188
+
189
class Expand_Small_96(_ExpandBase):
    """Pastes a 96x96 rect onto the reference canvas; see _ExpandBase."""
    SIZE = 96
191
+
192
+
193
# ComfyUI discovers custom nodes through these two module-level dicts.
_NODE_CLASSES = (
    Cropout_Big_384,
    Cropout_Mid_192,
    Cropout_Small_96,
    Expand_Big_384,
    Expand_Mid_192,
    Expand_Small_96,
)

# Node key and display name are both the class name.
NODE_CLASS_MAPPINGS = {cls.__name__: cls for cls in _NODE_CLASSES}

NODE_DISPLAY_NAME_MAPPINGS = {name: name for name in NODE_CLASS_MAPPINGS}