saliacoel committed
Commit 5988d40 (verified)
Parent(s): e636647

Upload FD_Standalone_2.py

Files changed (1)
  1. FD_Standalone_2.py +1087 -0
FD_Standalone_2.py ADDED
@@ -0,0 +1,1087 @@
import os
import sys
import importlib.util

import torch
import numpy as np
from ultralytics import YOLO

from comfy_extras import nodes_differential_diffusion

# NEW: import comfy core nodes (for CLIPTextEncode)
try:
    import nodes  # comfy-core nodes.py
except Exception:
    nodes = None


# -----------------------------
# helper loader
# -----------------------------
def _load_helpers():
    here = os.path.dirname(os.path.abspath(__file__))

    candidate_filenames = (
        "Salia_Facedetailer_Helpers.py",
        "Salia_Facedetailer_helpers.py",
        "Facedetailer_helpers.py",
    )

    try:
        from . import Salia_Facedetailer_Helpers as helpers  # type: ignore
        return helpers
    except (ImportError, ModuleNotFoundError):
        pass

    for fname in candidate_filenames:
        path = os.path.join(here, fname)
        if os.path.isfile(path):
            mod_name = os.path.splitext(fname)[0]
            spec = importlib.util.spec_from_file_location(mod_name, path)
            if spec is None or spec.loader is None:
                continue
            module = importlib.util.module_from_spec(spec)
            sys.modules[mod_name] = module
            spec.loader.exec_module(module)
            return module

    if here not in sys.path:
        sys.path.insert(0, here)

    import Salia_Facedetailer_Helpers as helpers  # type: ignore
    return helpers


helpers = _load_helpers()

# Make sure the helpers module is always importable under this canonical name
# (needed because we inlined TRT code that imports SEG from Salia_Facedetailer_Helpers)
try:
    if "Salia_Facedetailer_Helpers" not in sys.modules:
        sys.modules["Salia_Facedetailer_Helpers"] = helpers
except Exception:
    pass


# -----------------------------
# Lazy import for TRT_D_HYPA (TRT VAE decoder)
# -----------------------------
_TRTHYPA_MODULE = None
_TRTHYPA_DECODER_1344x768 = None

def _load_trt_d_hypa_module():
    """
    Locate and import TRT_D_HYPA.py from the comfyui-TRT_VAE custom node.
    We intentionally resolve it via filesystem paths so we do not depend on
    how ComfyUI chooses to package/import custom nodes.
    """
    here = os.path.dirname(os.path.abspath(__file__))

    # FD_Standalone_2.py: .../custom_nodes/comfyui-salia_facedetailer/nodes/FD_Standalone_2.py
    # -> custom_nodes
    custom_nodes_dir = os.path.dirname(os.path.dirname(here))
    trt_nodes_dir = os.path.join(custom_nodes_dir, "comfyui-TRT_VAE", "nodes")
    trt_file = os.path.join(trt_nodes_dir, "TRT_D_HYPA.py")

    if not os.path.isfile(trt_file):
        return None

    mod_name = "TRT_D_HYPA"

    # Reuse already-loaded module if present
    existing = sys.modules.get(mod_name)
    if existing is not None:
        return existing

    spec = importlib.util.spec_from_file_location(mod_name, trt_file)
    if spec is None or spec.loader is None:
        return None

    module = importlib.util.module_from_spec(spec)
    sys.modules[mod_name] = module
    try:
        spec.loader.exec_module(module)
    except Exception:
        # If import fails, remove the partially-loaded module to avoid poisoning sys.modules
        sys.modules.pop(mod_name, None)
        raise

    return module


def _get_trt_decoder_1344x768():
    """
    Return a singleton instance of TRT_D_HYPA_1344x768 (lazy-created).
    This keeps TensorRT engine initialization and memory allocations
    outside of the ComfyUI graph definition path and only runs them
    when the node is actually executed.
    """
    global _TRTHYPA_MODULE, _TRTHYPA_DECODER_1344x768

    if _TRTHYPA_DECODER_1344x768 is not None:
        return _TRTHYPA_DECODER_1344x768

    if _TRTHYPA_MODULE is None:
        _TRTHYPA_MODULE = _load_trt_d_hypa_module()

    if _TRTHYPA_MODULE is None:
        raise ImportError(
            "[FD_Standalone] Could not locate TRT_D_HYPA.py under comfyui-TRT_VAE/nodes. "
            "Make sure the comfyui-TRT_VAE custom node is installed."
        )

    try:
        DecoderCls = getattr(_TRTHYPA_MODULE, "TRT_D_HYPA_1344x768")
    except AttributeError as exc:
        raise ImportError(
            "[FD_Standalone] TRT_D_HYPA_1344x768 class not found inside TRT_D_HYPA.py."
        ) from exc

    _TRTHYPA_DECODER_1344x768 = DecoderCls()
    return _TRTHYPA_DECODER_1344x768
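
# Illustrative usage (as in doit() below): the accessor keeps a process-wide
# singleton, so TensorRT engine initialization runs at most once:
#   decoder = _get_trt_decoder_1344x768()
#   image = decoder.decode(latent)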


# -----------------------------
# Lazy import for Salia_FD_Parsed.py (NEXT TO THIS FILE)
# -----------------------------
_SALIA_FD_PARSED_MODULE = None
_SALIA_PARSED_NODE = None


def _load_salia_fd_parsed_module():
    """
    Load Salia_FD_Parsed.py from the same directory as this file (import-by-path).
    This remains valid if you move both files together to another folder.
    """
    global _SALIA_FD_PARSED_MODULE

    here = os.path.dirname(os.path.abspath(__file__))
    parsed_file = os.path.join(here, "Salia_FD_Parsed.py")

    if not os.path.isfile(parsed_file):
        raise FileNotFoundError(
            f"[FD_Standalone] Missing Salia_FD_Parsed.py next to FD_Standalone_2.py.\n"
            f"Expected: {parsed_file}"
        )

    mod_name = "Salia_FD_Parsed"

    existing = sys.modules.get(mod_name)
    if existing is not None:
        try:
            existing_file = os.path.abspath(getattr(existing, "__file__", "") or "")
            if existing_file == os.path.abspath(parsed_file) and hasattr(existing, "Salia_Parsed"):
                _SALIA_FD_PARSED_MODULE = existing
                return existing
        except Exception:
            pass

    spec = importlib.util.spec_from_file_location(mod_name, parsed_file)
    if spec is None or spec.loader is None:
        raise ImportError(f"[FD_Standalone] Failed to create import spec for: {parsed_file}")

    module = importlib.util.module_from_spec(spec)
    sys.modules[mod_name] = module
    try:
        spec.loader.exec_module(module)
    except Exception:
        sys.modules.pop(mod_name, None)
        raise

    if not hasattr(module, "Salia_Parsed"):
        raise ImportError(
            f"[FD_Standalone] Loaded {parsed_file}, but it does not define Salia_Parsed."
        )

    _SALIA_FD_PARSED_MODULE = module
    return module


def _get_salia_parsed_node():
    """Return a singleton instance of Salia_Parsed (lazy-created)."""
    global _SALIA_PARSED_NODE

    if _SALIA_PARSED_NODE is not None:
        return _SALIA_PARSED_NODE

    module = _load_salia_fd_parsed_module()
    ParserCls = getattr(module, "Salia_Parsed", None)
    if ParserCls is None:
        raise ImportError("[FD_Standalone] Salia_Parsed class not found in Salia_FD_Parsed.py.")

    _SALIA_PARSED_NODE = ParserCls()
    return _SALIA_PARSED_NODE
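
# Illustrative usage (as in doit() below): the parser turns the single prompt
# string into a (positive, negative) pair for the given POV:
#   parser = _get_salia_parsed_node()
#   pos, neg = parser.run(pov_id, prompt)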


# =====================================================================================
# INLINED: Salia_TRT_face.py (everything except the node wrapper)
# =====================================================================================

# Shared SEG definition (same fields as in Facedetailer_helpers)
try:
    from .Salia_Facedetailer_Helpers import SEG
except ImportError:
    # Fallback if used outside of a package
    from Salia_Facedetailer_Helpers import SEG


# -------------------------------------------------------------------------
# Constants
# -------------------------------------------------------------------------

NODE_DIR = os.path.dirname(os.path.abspath(__file__))

# Engine is always this exact filename, located next to this .py file
ENGINE_FILENAME = "salia_face.engine"

# Optional: cache to avoid re-loading the engine every execution
_YOLO_ENGINE_CACHE = {}


def load_yolo_detect(model_path: str) -> YOLO:
    """
    Load a YOLO model with task explicitly set to 'detect' to suppress:
        WARNING ⚠️ Unable to automatically guess model task...
    Works across Ultralytics versions by falling back if 'task=' isn't supported.
    """
    try:
        m = YOLO(model_path, task="detect")
    except TypeError:
        # Older Ultralytics versions may not accept 'task=' in the constructor
        m = YOLO(model_path)

    # Reinforce task in case the backend/model doesn't carry task metadata (e.g. TRT engine)
    try:
        m.task = "detect"
    except Exception:
        pass

    try:
        if hasattr(m, "overrides") and isinstance(m.overrides, dict):
            m.overrides["task"] = "detect"
    except Exception:
        pass

    return m


def load_engine_model(engine_path: str) -> YOLO:
    """Load (and cache) the TensorRT engine as a YOLO detect model."""
    m = _YOLO_ENGINE_CACHE.get(engine_path)
    if m is None:
        m = load_yolo_detect(engine_path)
        _YOLO_ENGINE_CACHE[engine_path] = m
    return m
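
# Illustrative caching behavior (hypothetical path): repeated calls return the
# same cached model, so the TensorRT engine is deserialized only once:
#   m1 = load_engine_model("/path/to/salia_face.engine")
#   m2 = load_engine_model("/path/to/salia_face.engine")
#   assert m1 is m2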


# -------------------------------------------------------------------------
# Helpers (mirrors Salia_BBOX.py behavior)
# -------------------------------------------------------------------------


def tensor_to_pil(image: torch.Tensor):
    """Convert a ComfyUI IMAGE tensor [B,H,W,C] (0..1) to a PIL RGB image (first item in batch)."""
    from PIL import Image

    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(image)}")

    if image.dim() == 4:
        img = image[0]
    else:
        img = image

    img = img.detach()
    if img.is_cuda:
        img = img.cpu()

    img = img.clamp(0, 1).numpy()
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)

    img_u8 = (img * 255.0).round().astype(np.uint8)
    return Image.fromarray(img_u8)
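
# Shape note (illustrative values): a [1, 768, 1344, 3] float tensor in 0..1
# becomes a 1344x768 PIL RGB image; a single-channel [..., 1] tensor is
# repeated to 3 channels before conversion.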


def make_crop_region(w: int, h: int, bbox_xyxy, crop_factor: float, crop_min_size=None):
    """Expanded bbox crop-region logic, clamped to image."""
    try:
        x1f = float(bbox_xyxy[0])
        y1f = float(bbox_xyxy[1])
        x2f = float(bbox_xyxy[2])
        y2f = float(bbox_xyxy[3])
    except Exception:
        x1f = y1f = x2f = y2f = 0.0

    bbox_w = max(1.0, x2f - x1f)
    bbox_h = max(1.0, y2f - y1f)

    crop_w = bbox_w * float(crop_factor)
    crop_h = bbox_h * float(crop_factor)

    if crop_min_size is not None:
        crop_w = max(crop_w, float(crop_min_size))
        crop_h = max(crop_h, float(crop_min_size))

    cx = (x1f + x2f) / 2.0
    cy = (y1f + y2f) / 2.0

    rx1 = int(round(cx - crop_w / 2.0))
    ry1 = int(round(cy - crop_h / 2.0))
    rx2 = int(round(cx + crop_w / 2.0))
    ry2 = int(round(cy + crop_h / 2.0))

    # clamp
    rx1 = max(0, min(w - 1, rx1))
    ry1 = max(0, min(h - 1, ry1))
    rx2 = max(rx1 + 1, min(w, rx2))
    ry2 = max(ry1 + 1, min(h, ry2))

    return (rx1, ry1, rx2, ry2)
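
# Worked example (illustrative numbers): in a 1024x768 image, bbox
# (100, 100, 200, 220) with crop_factor=1.5 gives a 150x180 window centered
# on the bbox:
#   make_crop_region(1024, 768, (100, 100, 200, 220), 1.5) -> (75, 70, 225, 250)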


def crop_image(image: torch.Tensor, crop_region):
    """Crop a ComfyUI IMAGE tensor [B,H,W,C] using (x1,y1,x2,y2)."""
    x1, y1, x2, y2 = crop_region
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

    if image.dim() == 4:
        return image[:, y1:y2, x1:x2, :]
    if image.dim() == 3:
        return image[y1:y2, x1:x2, :]
    raise ValueError(f"Unexpected image tensor shape: {tuple(image.shape)}")


def crop_ndarray2(arr: np.ndarray, crop_region):
    """Crop a 2D numpy array using (x1,y1,x2,y2)."""
    x1, y1, x2, y2 = crop_region
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
    return arr[y1:y2, x1:x2]


try:
    import cv2  # opencv-python or opencv-python-headless
except Exception:
    cv2 = None


def dilate_masks(segmasks, dilation: int):
    """Dilate masks only if dilation > 0 and cv2 is available."""
    if dilation <= 0:
        return segmasks
    if cv2 is None:
        return segmasks

    k = int(dilation)
    ksize = k * 2 + 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))

    out = []
    for bbox, mask, conf in segmasks:
        try:
            m = (mask > 0.5).astype(np.uint8) * 255
            m = cv2.dilate(m, kernel, iterations=1)
            out_mask = (m > 0).astype(np.float32)
            out.append((bbox, out_mask, conf))
        except Exception:
            out.append((bbox, mask, conf))
    return out
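
# Kernel note (illustrative): dilation=k uses a (2k+1)x(2k+1) elliptical
# structuring element, e.g. dilation=4 grows each mask with a 9x9 ellipse;
# masks pass through unchanged when cv2 is unavailable.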


def combine_masks(segmasks, out_shape_hw=None) -> torch.Tensor:
    """Combine multiple masks using max()."""
    if not segmasks:
        if out_shape_hw is None:
            return torch.zeros((1, 1, 1), dtype=torch.float32)
        h, w = out_shape_hw
        return torch.zeros((1, h, w), dtype=torch.float32)

    base = segmasks[0][1]
    combined = np.zeros_like(base, dtype=np.float32)
    for _, m, _ in segmasks:
        try:
            combined = np.maximum(combined, m.astype(np.float32))
        except Exception:
            pass

    return torch.from_numpy(combined).unsqueeze(0)
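
# Illustrative: two overlapping face masks combine via element-wise max, so
# combine_masks([(b1, m1, c1), (b2, m2, c2)]) yields one [1, H, W] float
# tensor that is 1.0 wherever either input mask was set.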


def _create_segmasks(results):
    """Create a list of (bbox, mask_float32, conf) tuples."""
    bboxes = results[1]
    segms = results[2]
    confs = results[3]

    out = []
    try:
        n = int(len(segms))
    except Exception:
        n = 0

    for i in range(n):
        try:
            out.append((bboxes[i], segms[i].astype(np.float32), confs[i]))
        except Exception:
            pass

    return out


def _inference_bbox(model, image_pil, confidence: float = 0.3, device: str = "0"):
    """
    Run bbox inference and return:
        [labels, bboxes_xyxy_list, segm_masks_list, confs_list]
    where segm_masks are full-image boolean masks (rectangle fill per bbox).
    """
    pred = model(image_pil, conf=float(confidence), device=str(device), verbose=False)

    bboxes = pred[0].boxes.xyxy.cpu().numpy()  # xyxy
    if bboxes is None or (hasattr(bboxes, "shape") and bboxes.shape[0] == 0):
        return [[], [], [], []]

    # Original image size; PIL .size is (W, H)
    w_orig, h_orig = image_pil.size
    ih = int(h_orig)
    iw = int(w_orig)

    segms = []
    for (x0, y0, x1, y1) in bboxes:
        m = np.zeros((ih, iw), dtype=np.uint8)

        # Clamp coords
        try:
            x0i = int(x0)
        except Exception:
            x0i = 0
        try:
            y0i = int(y0)
        except Exception:
            y0i = 0
        try:
            x1i = int(x1)
        except Exception:
            x1i = 0
        try:
            y1i = int(y1)
        except Exception:
            y1i = 0

        x0c = max(0, min(iw - 1, x0i))
        x1c = max(x0c + 1, min(iw, x1i))
        y0c = max(0, min(ih - 1, y0i))
        y1c = max(y0c + 1, min(ih, y1i))

        if cv2 is not None:
            try:
                cv2.rectangle(m, (x0c, y0c), (x1c, y1c), 255, -1)
            except Exception:
                m[y0c:y1c, x0c:x1c] = 255
        else:
            m[y0c:y1c, x0c:x1c] = 255

        segms.append((m > 0))

    labels = []
    confs = []

    names = getattr(pred[0], "names", None)
    names_is_seq = isinstance(names, (list, tuple))

    for i in range(len(bboxes)):
        # label
        label = "unknown"
        try:
            cls_idx = int(pred[0].boxes[i].cls.item())
            if names_is_seq:
                label = names[cls_idx] if 0 <= cls_idx < len(names) else str(cls_idx)
            elif isinstance(names, dict):
                label = names.get(cls_idx, str(cls_idx))
            else:
                label = str(cls_idx)
        except Exception:
            label = "unknown"

        # conf (force to float)
        try:
            conf_val = float(pred[0].boxes[i].conf.item())
        except Exception:
            conf_val = 0.0

        labels.append(conf_val)  # NOTE: deliberately stores the confidence, kept as-is from the original code
        confs.append(conf_val)

    return [labels, list(bboxes), segms, confs]
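
# Illustrative return value (hypothetical numbers) for one detected face on a
# 1344x768 input:
#   [[0.91], [array([320., 140., 520., 380.])], [bool mask of shape (768, 1344)], [0.91]]
# Note the labels list mirrors the confidences (see NOTE above), so SEG.label
# below carries a confidence value rather than a class name.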


# -------------------------------------------------------------------------
# YOLO TensorRT-based BBOX_DETECTOR implementation
# -------------------------------------------------------------------------


class TRTYOLOBBoxDetector:
    """BBOX_DETECTOR interface compatible with FaceDetailer."""

    def __init__(self, yolo_model: YOLO, device: str = "0"):
        self.bbox_model = yolo_model
        self.device = device or "0"

    def setAux(self, x: str):
        # Kept for interface compatibility
        pass

    def detect(
        self,
        image: torch.Tensor,
        threshold: float,
        dilation: int,
        crop_factor: float,
        drop_size: int = 1,
        detailer_hook=None,
    ):
        """Return FaceDetailer-style SEGS: ( (H, W), [SEG, ...] )."""
        if not isinstance(image, torch.Tensor):
            raise TypeError(f"[TRTYOLOBBoxDetector] Expected torch.Tensor for image, got {type(image)}")
        if image.dim() != 4:
            raise ValueError("[TRTYOLOBBoxDetector] Expected IMAGE tensor with 4 dims [B, H, W, C].")

        h, w = int(image.shape[1]), int(image.shape[2])
        shape = (h, w)

        detected = _inference_bbox(
            self.bbox_model,
            tensor_to_pil(image),
            confidence=float(threshold),
            device=str(self.device),
        )

        segmasks = _create_segmasks(detected)

        if int(dilation) > 0:
            segmasks = dilate_masks(segmasks, int(dilation))

        drop_size_int = int(drop_size) if int(drop_size) > 0 else 1

        items = []
        for (bbox, mask, conf), label in zip(segmasks, detected[0]):
            try:
                x1f = float(bbox[0])
                y1f = float(bbox[1])
                x2f = float(bbox[2])
                y2f = float(bbox[3])
            except Exception:
                continue

            bwf = x2f - x1f
            bhf = y2f - y1f

            if bwf > drop_size_int and bhf > drop_size_int:
                crop_region = make_crop_region(w, h, bbox, float(crop_factor))

                if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"):
                    try:
                        crop_region = detailer_hook.post_crop_region(w, h, bbox, crop_region)
                    except Exception:
                        pass

                cropped_image = crop_image(image, crop_region)
                cropped_mask = crop_ndarray2(mask, crop_region)

                items.append(SEG(cropped_image, cropped_mask, conf, crop_region, bbox, label, None))

        segs = (shape, items)

        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            try:
                segs = detailer_hook.post_detection(segs)
            except Exception:
                pass

        return segs

    def detect_combined(self, image: torch.Tensor, threshold: float, dilation: int) -> torch.Tensor:
        """Return a single combined MASK tensor covering all detections."""
        if not isinstance(image, torch.Tensor):
            raise TypeError(f"[TRTYOLOBBoxDetector] Expected torch.Tensor for image, got {type(image)}")
        if image.dim() != 4:
            raise ValueError("[TRTYOLOBBoxDetector] Expected IMAGE tensor with 4 dims [B, H, W, C].")

        detected = _inference_bbox(
            self.bbox_model,
            tensor_to_pil(image),
            confidence=float(threshold),
            device=str(self.device),
        )

        segmasks = _create_segmasks(detected)
        if int(dilation) > 0:
            segmasks = dilate_masks(segmasks, int(dilation))

        return combine_masks(segmasks, out_shape_hw=(int(image.shape[1]), int(image.shape[2])))
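
# Illustrative usage (mirrors what FD_Standalone_2 does internally):
#   detector = TRTYOLOBBoxDetector(load_engine_model(engine_path), device="0")
#   segs = detector.detect(image, 0.55, 0, 1.0, 10)   # -> ((H, W), [SEG, ...])
#   mask = detector.detect_combined(image, 0.55, 0)   # -> [1, H, W] MASK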


# =====================================================================================
# END INLINED: Salia_TRT_face.py
# =====================================================================================


# -----------------------------
# CLIP Text Encode (core) wrapper
# -----------------------------
_CLIP_TEXT_ENCODE_NODE = None


def _encode_conditioning(clip, text: str):
    """
    Use the comfy-core CLIPTextEncode node (preferred), with a robust fallback
    for older/newer core APIs.
    """
    global _CLIP_TEXT_ENCODE_NODE

    if text is None:
        text = ""

    # Preferred: call comfy-core node CLIPTextEncode
    if nodes is not None:
        if _CLIP_TEXT_ENCODE_NODE is None:
            _CLIP_TEXT_ENCODE_NODE = nodes.CLIPTextEncode()

        # Core node returns a tuple: (conditioning,)
        return _CLIP_TEXT_ENCODE_NODE.encode(clip=clip, text=text)[0]

    # Fallback if `import nodes` failed in this environment:
    if clip is None:
        raise RuntimeError("CLIP input is None (cannot encode).")

    tokens = clip.tokenize(text)

    # Newer API (2024/2025+)
    if hasattr(clip, "encode_from_tokens_scheduled"):
        return clip.encode_from_tokens_scheduled(tokens)

    # Older API fallback
    output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
    cond = output.pop("cond")
    return [[cond, output]]
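
# Illustrative (as used in doit() below): each prompt is encoded exactly once
# per node execution, not per face or per segment:
#   positive = _encode_conditioning(clip, pos)
#   negative = _encode_conditioning(clip, neg)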


def _manual_bbox_from_ltrb(left, top, right, bottom):
    """
    Manual bbox override from 4 ints: (left, top, right, bottom).

    These 4 ints imply the 4 corners:
      - Top-left     = (left, top)
      - Top-right    = (right, top)
      - Bottom-left  = (left, bottom)
      - Bottom-right = (right, bottom)

    Convention:
      - If ANY value is None or < 0 -> return None (use YOLO detection).
      - Otherwise returns (x1, y1, x2, y2) with correct ordering.
    """
    if left is None or top is None or right is None or bottom is None:
        return None

    try:
        x1 = int(left)
        y1 = int(top)
        x2 = int(right)
        y2 = int(bottom)
    except Exception:
        return None

    # Sentinel: any negative => auto detect
    if x1 < 0 or y1 < 0 or x2 < 0 or y2 < 0:
        return None

    # Ensure proper ordering
    if x2 < x1:
        x1, x2 = x2, x1
    if y2 < y1:
        y1, y2 = y2, y1

    return (x1, y1, x2, y2)
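
# Worked examples (illustrative):
#   _manual_bbox_from_ltrb(-1, 0, 512, 512)    -> None (negative sentinel => YOLO)
#   _manual_bbox_from_ltrb(300, 100, 100, 300) -> (100, 100, 300, 300) (reordered)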


class FD_Standalone_2:
    _BBOX_DETECTOR = None

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # CHANGED: take latent instead of image, and internally decode via TRT_D_HYPA_1344x768
                "latent": (
                    "LATENT",
                    {
                        "tooltip": "Latent to be decoded with TRT_D_HYPA_1344x768 before face detailing."
                    },
                ),
                "model": ("MODEL", {"tooltip": "If ImpactDummyInput connected, inference may be skipped."}),
                # single CLIP input (from Load Checkpoint)
                "clip": ("CLIP", {"tooltip": "CLIP from Load Checkpoint (SDXL CLIP is fine)."}),

                # NEW: manual bbox override via 4 ints (left/top/right/bottom)
                # Leave any value at -1 to use YOLO auto-detection.
                "bbox_left": (
                    "INT",
                    {
                        "default": -1,
                        "min": -1,
                        "max": 1000000,
                        "step": 1,
                        "tooltip": "Manual bbox LEFT (x1). Top-left=(LEFT,TOP), Bottom-left=(LEFT,BOTTOM). -1 => YOLO auto-detect.",
                    },
                ),
                "bbox_top": (
                    "INT",
                    {
                        "default": -1,
                        "min": -1,
                        "max": 1000000,
                        "step": 1,
                        "tooltip": "Manual bbox TOP (y1). Top-left=(LEFT,TOP), Top-right=(RIGHT,TOP). -1 => YOLO auto-detect.",
                    },
                ),
                "bbox_right": (
                    "INT",
                    {
                        "default": -1,
                        "min": -1,
                        "max": 1000000,
                        "step": 1,
                        "tooltip": "Manual bbox RIGHT (x2). Top-right=(RIGHT,TOP), Bottom-right=(RIGHT,BOTTOM). -1 => YOLO auto-detect.",
                    },
                ),
                "bbox_bottom": (
                    "INT",
                    {
                        "default": -1,
                        "min": -1,
                        "max": 1000000,
                        "step": 1,
                        "tooltip": "Manual bbox BOTTOM (y2). Bottom-left=(LEFT,BOTTOM), Bottom-right=(RIGHT,BOTTOM). -1 => YOLO auto-detect.",
                    },
                ),

                # POV integer
                "pov_id": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 4,
                        "step": 1,
                        "tooltip": "POV: 1=front, 2=three-quarter, 3=side, 4=rear. If 4, node bypasses and outputs decoded image unchanged.",
                    },
                ),

                # single input string, internally parsed by Salia_Parsed into (pos, neg)
                "prompt": (
                    "STRING",
                    {
                        "multiline": True,
                        "default": "",
                        "dynamicPrompts": True,
                        "tooltip": "Single prompt string. Internally parsed by Salia_Parsed into (pos, neg) for face detailing.",
                    },
                ),
            },
            "optional": {},
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    OUTPUT_IS_LIST = (False,)
    FUNCTION = "doit"
    CATEGORY = "ImpactPack/Simple"

    @classmethod
    def _get_bbox_detector(cls):
        if cls._BBOX_DETECTOR is not None:
            return cls._BBOX_DETECTOR

        engine_path = os.path.join(NODE_DIR, ENGINE_FILENAME)

        if not os.path.isfile(engine_path):
            raise FileNotFoundError(
                f"[TRTYOLOBBoxDetectorProvider] Engine file not found: {engine_path}\n"
                f"Expected the file '{ENGINE_FILENAME}' next to this node .py file."
            )

        yolo_model = load_engine_model(engine_path)
        detector = TRTYOLOBBoxDetector(yolo_model, device="0")
        cls._BBOX_DETECTOR = detector
        return cls._BBOX_DETECTOR

    @staticmethod
    def enhance_face(image, model, positive, negative, bbox_detector=None, manual_bbox=None):
        """
        If manual_bbox is provided (x1,y1,x2,y2), skip the detector and detail only that region.
        Otherwise use bbox_detector.detect(...) (original behavior).
        """
        # Manual override path
        if manual_bbox is not None:
            try:
                return DetailerForEach.do_detail_bbox(image, manual_bbox, model, positive, negative)
            except Exception:
                return image

        # Original detection path
        if bbox_detector is None:
            return image

        try:
            bbox_detector.setAux("face")
        except Exception:
            pass

        try:
            segs = bbox_detector.detect(image, 0.55, 0, 1.0, 10)
        except Exception:
            try:
                bbox_detector.setAux(None)
            except Exception:
                pass
            return image

        try:
            bbox_detector.setAux(None)
        except Exception:
            pass

        try:
            num_segs = int(len(segs[1]))
        except Exception:
            num_segs = 0

        if num_segs == 0:
            return image

        try:
            out = DetailerForEach.do_detail(image, segs, model, positive, negative)
            return out
        except Exception:
            return image

    def doit(self, latent, model, clip, bbox_left, bbox_top, bbox_right, bbox_bottom, pov_id, prompt):
        # Step 1: decode latent -> image using the TRT VAE decoder
        decoder = _get_trt_decoder_1344x768()
        decoded = decoder.decode(latent)
        if isinstance(decoded, (list, tuple)):
            image = decoded[0]
        else:
            image = decoded

        # Normalize POV (1..4)
        try:
            pov_id_int = int(pov_id)
        except Exception:
            pov_id_int = 1
        if pov_id_int < 1:
            pov_id_int = 1
        if pov_id_int > 4:
            pov_id_int = 4

        # POV=4 (rear view): skip the entire task and output the decoded image unchanged
        if pov_id_int == 4:
            return (image,)

        # Parse the single prompt string -> (pos, neg)
        if prompt is None:
            prompt = ""
        parser = _get_salia_parsed_node()
        try:
            pos, neg = parser.run(pov_id_int, prompt)
        except Exception as exc:
            raise RuntimeError(f"[FD_Standalone] Salia_Parsed failed: {exc}") from exc

        # Encode ONCE per node execution (not per face / not per segment)
        skip_inference = isinstance(model, str) and model == "DUMMY"

        if skip_inference:
            positive = []
            negative = []
        else:
            positive = _encode_conditioning(clip, pos)
            negative = _encode_conditioning(clip, neg)

        # Decide manual bbox vs detector:
        #   If bbox_left/top/right/bottom are all >= 0 -> manual override.
        #   Otherwise -> YOLO detection.
        manual_bbox = _manual_bbox_from_ltrb(bbox_left, bbox_top, bbox_right, bbox_bottom)

        # Only load detector if needed
        bbox_detector = None
        if manual_bbox is None:
            bbox_detector = FD_Standalone_2._get_bbox_detector()

        outs = []
        # Image from TRT VAE is [B,H,W,C]; iterate over batch dimension
        for img in image:
            try:
                out = self.enhance_face(
                    img.unsqueeze(0),
                    model,
                    positive,
                    negative,
                    bbox_detector=bbox_detector,
                    manual_bbox=manual_bbox,
                )
            except Exception:
                out = img.unsqueeze(0)
            outs.append(out)

        try:
            result = torch.cat(outs, dim=0)
        except Exception:
            result = image

        return (result,)
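
# End-to-end flow (summary): decode the latent via the TRT VAE, parse the
# prompt into (pos, neg), encode both once, then per batch item detail either
# the manual bbox or every detected face, and re-batch with torch.cat.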


class DetailerForEach:
    @staticmethod
    def do_detail_bbox(image, bbox, model, positive, negative):
        """
        NEW: Detail exactly one bbox (x1,y1,x2,y2) without needing SEGS/detection.
        Uses the same square-crop/detail/paste logic as do_detail().
        """
        try:
            image = image.clone().cpu()
        except Exception:
            pass

        # Clamp bbox to image bounds (best-effort safety)
        try:
            h = int(image.shape[1])
            w = int(image.shape[2])
        except Exception:
            h, w = 0, 0

        try:
            x1, y1, x2, y2 = bbox
            x1 = int(x1)
            y1 = int(y1)
            x2 = int(x2)
            y2 = int(y2)
        except Exception:
            return image

        if w > 0 and h > 0:
            x1 = max(0, min(w - 1, x1))
            y1 = max(0, min(h - 1, y1))
            x2 = max(x1 + 1, min(w, x2))
            y2 = max(y1 + 1, min(h, y2))

        bbox_clamped = (x1, y1, x2, y2)

        try:
            model = nodes_differential_diffusion.DifferentialDiffusion().apply(model)[0]
        except Exception:
            pass

        try:
            rx1, ry1, side, _, _, _, _ = helpers.bbox_to_square_region(bbox_clamped, max_side=1024)
        except Exception:
            return image

        square_patch = helpers.crop_with_pad_nhwc(image, rx1, ry1, side, fill=0.0)
        if square_patch is None:
            return image

        try:
            if square_patch is not None and not (isinstance(model, str) and model == "DUMMY"):
                premult_side, alpha_side = helpers.enhance_detail_bbox_square(
                    square_patch,
                    model,
                    positive,
                    negative,
                    side=side,
                )
            else:
                premult_side = square_patch
                alpha_side = torch.ones(
                    (1, side, side, 1),
                    dtype=square_patch.dtype,
                    device=square_patch.device,
                )
        except Exception:
            return image

        try:
            helpers.tensor_paste_premult_oob(image, premult_side, alpha_side, (rx1, ry1))
        except Exception:
            pass

        try:
            out = helpers.tensor_convert_rgb(image)
        except Exception:
            out = image

        return out

    @staticmethod
    def do_detail(image, segs, model, positive, negative):
        try:
            image = image.clone().cpu()
        except Exception:
            pass

        try:
            _, ordered_segs = helpers.segs_scale_match(segs, image.shape)
        except Exception:
            ordered_segs = segs[1] if (segs and len(segs) > 1) else []

        try:
            model = nodes_differential_diffusion.DifferentialDiffusion().apply(model)[0]
        except Exception:
            pass

        for seg in ordered_segs:
            try:
                rx1, ry1, side, _, _, _, _ = helpers.bbox_to_square_region(seg.bbox, max_side=1024)
            except Exception:
                continue

            square_patch = helpers.crop_with_pad_nhwc(image, rx1, ry1, side, fill=0.0)

            try:
                if square_patch is not None and not (isinstance(model, str) and model == "DUMMY"):
                    premult_side, alpha_side = helpers.enhance_detail_bbox_square(
                        square_patch,
                        model,
                        positive,
                        negative,
                        side=side,
                    )
                else:
                    premult_side = square_patch
                    alpha_side = torch.ones(
                        (1, side, side, 1),
                        dtype=square_patch.dtype,
                        device=square_patch.device,
                    )
            except Exception:
                continue

            try:
                helpers.tensor_paste_premult_oob(image, premult_side, alpha_side, (rx1, ry1))
            except Exception:
                pass

        try:
            out = helpers.tensor_convert_rgb(image)
        except Exception:
            out = image

        return out


NODE_CLASS_MAPPINGS = {
    "FD_Standalone_2": FD_Standalone_2,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "FD_Standalone_2": "FD_Standalone_2",
}
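
# ComfyUI discovers this node via the two mappings above: the key
# "FD_Standalone_2" is the internal class name workflows reference, and the
# display-name mapping controls the label shown in the node picker.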