| """ |
| Salia Ultralytics Detector Provider (ComfyUI custom node) |
| |
| Goal: |
| - Provide the same outputs as Impact-Subpack's `UltralyticsDetectorProvider`: |
| - BBOX_DETECTOR |
| - SEGM_DETECTOR |
| - But packaged so you can drop it into your own custom node folder (your Salia_* environment) |
| without requiring ComfyUI-Impact-Subpack. |
| |
| Notes: |
| - This file intentionally keeps dependencies minimal and self-contained. |
| - It uses `ultralytics.YOLO` to run `.pt` models directly (no TensorRT build step). |
| - For PyTorch >= 2.6, `torch.load` defaults to `weights_only=True` which can break |
| legacy `.pt` checkpoints. This file adds an OPTIONAL whitelist-based fallback |
| to `weights_only=False` (unsafe) for specifically trusted model filenames. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import logging |
| import pickle |
| from datetime import datetime |
| from contextlib import contextmanager |
| from collections import namedtuple |
|
|
| import folder_paths |
|
|
| from PIL import Image |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| try: |
| import cv2 |
| except Exception: |
| cv2 = None |
|
|
|
|
| |
| |
| |
|
|
# File extensions accepted in every ultralytics folder registered below.
# Uses ComfyUI's own `folder_paths.supported_pt_extensions` when that attribute
# exists, otherwise a hard-coded fallback list.
# NOTE(review): the models are ultimately loaded by `ultralytics.YOLO`; whether
# every listed extension is loadable by it is not established by this file.
_SUPPORTED_PT_EXTS = getattr(folder_paths, "supported_pt_extensions", [".pt", ".pth", ".ckpt", ".safetensors"])
|
|
|
|
def _add_folder_path_and_extensions(
    folder_name: str,
    paths: list[str],
    extensions: list[str] | tuple[str, ...] | set[str],
):
    """Add or merge a ``folder_paths`` entry without depending on Impact-Pack helpers.

    Existing search paths keep their order; new paths are appended unless already
    present. Extensions are merged into a ``set``, which is the type ComfyUI itself
    stores in ``folder_names_and_paths[name][1]``. Storing a tuple there would break
    any code that later merges into the entry with set operations (``|``, ``.add()``,
    ``.update()``), for example another custom node that loads after this one.

    Args:
        folder_name: ``folder_paths`` key, e.g. ``"ultralytics_bbox"``.
        paths: directories to register for that key.
        extensions: file extensions (with leading dot) to accept for that key.
    """
    registry = folder_paths.folder_names_and_paths

    if folder_name in registry:
        existing_paths, existing_exts = registry[folder_name]
        merged_paths = list(existing_paths)
        merged_exts = set(existing_exts)
    else:
        merged_paths = []
        merged_exts = set()

    for p in paths:
        if p not in merged_paths:
            merged_paths.append(p)
    merged_exts.update(extensions)

    registry[folder_name] = (merged_paths, merged_exts)
|
|
|
|
def _update_model_paths(base_path: str):
    """Register the Impact-Subpack ultralytics folder layout found under ``base_path``.

    Registers ``<base>/ultralytics/bbox`` as ``ultralytics_bbox``,
    ``<base>/ultralytics/segm`` as ``ultralytics_segm`` and ``<base>/ultralytics``
    as ``ultralytics``, each with ``_SUPPORTED_PT_EXTS``.
    """
    ultra_root = os.path.join(base_path, "ultralytics")
    layout = (
        ("ultralytics_bbox", os.path.join(ultra_root, "bbox")),
        ("ultralytics_segm", os.path.join(ultra_root, "segm")),
        ("ultralytics", ultra_root),
    )
    for key, directory in layout:
        _add_folder_path_and_extensions(key, [directory], _SUPPORTED_PT_EXTS)
|
|
|
|
| |
# Register the Impact-Subpack style folders under ComfyUI's main models directory.
_update_model_paths(folder_paths.models_dir)

# Also register them under the optional "download_model_base" root, when configured.
if "download_model_base" in folder_paths.folder_names_and_paths:
    try:
        _update_model_paths(folder_paths.get_folder_paths("download_model_base")[0])
    except Exception as exc:
        # Best effort: a missing/empty optional root must not stop this node from
        # loading, but report it instead of silently losing the extra search path.
        logging.warning(
            "[Salia Ultralytics] Could not register ultralytics paths under 'download_model_base': %s",
            exc,
        )

# Let models shipped alongside this node (./nodes, ./models, or this folder itself)
# be discovered under every ultralytics key.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
for local_dir in [
    os.path.join(_THIS_DIR, "nodes"),
    os.path.join(_THIS_DIR, "models"),
    _THIS_DIR,
]:
    if os.path.isdir(local_dir):
        _add_folder_path_and_extensions("ultralytics_bbox", [local_dir], _SUPPORTED_PT_EXTS)
        _add_folder_path_and_extensions("ultralytics_segm", [local_dir], _SUPPORTED_PT_EXTS)
        _add_folder_path_and_extensions("ultralytics", [local_dir], _SUPPORTED_PT_EXTS)
|
|
|
|
| |
| |
| |
|
|
# The real `torch.load`, captured once at import time. `_torch_load_wrapper` calls
# this (not `torch.load`) so it still reaches the original loader while
# `_patched_torch_load_for_ultralytics` has `torch.load` temporarily replaced.
_ORIG_TORCH_LOAD = torch.load
|
|
|
|
def _get_whitelist_file() -> str | None:
    """Return the trusted-model whitelist path, creating the file on first use.

    The file lives at
    ``<ComfyUI user dir>/default/ComfyUI-Salia-Ultralytics/model-whitelist.txt``.
    When it does not exist yet it is created with an explanatory comment header.

    Returns:
        The file path, or ``None`` if the user directory is unavailable or the
        directory/file cannot be created.
    """
    try:
        user_dir = folder_paths.get_user_directory()
    except Exception:
        return None
    if not (user_dir and os.path.isdir(user_dir)):
        return None

    target_dir = os.path.join(user_dir, "default", "ComfyUI-Salia-Ultralytics")
    target = os.path.join(target_dir, "model-whitelist.txt")
    header = (
        "# Add base filenames of trusted legacy models here (one per line).\n",
        "# Example: eyes.pt\n",
        "# These will be allowed to load with weights_only=False if safe loading fails.\n",
        "# WARNING: Only add models you trust.\n",
    )
    try:
        os.makedirs(target_dir, exist_ok=True)
        if not os.path.exists(target):
            with open(target, "w", encoding="utf-8") as handle:
                handle.writelines(header)
    except Exception:
        return None
    return target
|
|
|
|
# Path of the trusted-model whitelist (created with a comment header on first
# import), or None when ComfyUI's user directory is unavailable or the file
# could not be created.
_WHITELIST_PATH = _get_whitelist_file()
|
|
|
|
| |
| |
| |
|
|
def _get_model_load_log_file() -> str:
    """Return the path of the log that records which model file was loaded.

    Resolution order:
      1. next to the whitelist file, when ``_WHITELIST_PATH`` is set;
      2. ``<ComfyUI user dir>/default/ComfyUI-Salia-Ultralytics/`` (the directory is
         created on a best-effort basis);
      3. this node's own directory.
    """
    log_name = "model-load-log.txt"

    if _WHITELIST_PATH:
        return os.path.join(os.path.dirname(_WHITELIST_PATH), log_name)

    try:
        user_dir = folder_paths.get_user_directory()
    except Exception:
        user_dir = None

    if not user_dir or not os.path.isdir(user_dir):
        return os.path.join(_THIS_DIR, log_name)

    log_dir = os.path.join(user_dir, "default", "ComfyUI-Salia-Ultralytics")
    try:
        os.makedirs(log_dir, exist_ok=True)
    except Exception:
        pass
    return os.path.join(log_dir, log_name)
|
|
|
|
# Resolved once at import time; `_log_selected_model` appends one line per load here.
_MODEL_LOAD_LOG_PATH = _get_model_load_log_file()
|
|
|
|
def _find_all_model_paths(model_name: str) -> list[str]:
    """Find every on-disk file matching ``model_name`` across the ultralytics folders.

    Candidates are checked in this order: each ``ultralytics`` root joined with
    ``model_name``, then (for ``bbox/``- or ``segm/``-prefixed names) each
    ``ultralytics_bbox``/``ultralytics_segm`` root joined with the name minus its
    prefix. This is the same order ``SaliaUltralyticsDetectorProvider2.doit``
    resolves models in.

    Only regular files are reported (``os.path.isfile``), matching
    ``folder_paths.get_full_path``. A directory that merely shares the model's
    name is therefore not listed as a match.

    Returns:
        Absolute paths, de-duplicated, in first-seen order.
    """

    def roots(folder_name: str) -> list[str]:
        # A key that is not registered simply contributes no roots.
        try:
            return list(folder_paths.get_folder_paths(folder_name))
        except Exception:
            return []

    candidates = [os.path.join(root, model_name) for root in roots("ultralytics")]

    for prefix, folder_name in (("bbox/", "ultralytics_bbox"), ("segm/", "ultralytics_segm")):
        if model_name.startswith(prefix):
            rel = model_name[len(prefix):]
            candidates.extend(os.path.join(root, rel) for root in roots(folder_name))
            break

    found = (os.path.abspath(p) for p in candidates if os.path.isfile(p))
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(found))
|
|
|
|
def _log_selected_model(model_name: str, model_path: str, matches: list[str] | None = None):
    """Report which ultralytics model file is about to be loaded.

    The same information goes to three sinks, in this order:
      1. ``print`` to the console;
      2. the root ``logging`` logger;
      3. one appended tab-separated line in ``_MODEL_LOAD_LOG_PATH``. The line holds
         the timestamp, name, path, whether the path is a file and its size (``-1``
         if not). One extra ``match`` line is written per candidate when more than
         one on-disk match was found.

    Failures while writing the log file are reported with ``logging.warning`` and
    otherwise ignored.
    """
    ambiguous = bool(matches) and len(matches) > 1

    print(f"[Salia Ultralytics] Selected model_name: {model_name}")
    print(f"[Salia Ultralytics] Resolved model_path: {model_path}")
    if ambiguous:
        print("[Salia Ultralytics] Multiple matches found (first one is used by get_full_path):")
        for candidate in matches:
            print(f" - {candidate}")
    print(f"[Salia Ultralytics] Model load log file: {_MODEL_LOAD_LOG_PATH}")

    logging.info("[Salia Ultralytics] Selected model_name: %s", model_name)
    logging.info("[Salia Ultralytics] Resolved model_path: %s", model_path)
    if ambiguous:
        logging.warning("[Salia Ultralytics] Multiple matches found (first one is used by get_full_path):")
        for candidate in matches:
            logging.warning(" - %s", candidate)
    logging.info("[Salia Ultralytics] Model load log file: %s", _MODEL_LOAD_LOG_PATH)

    try:
        stamp = datetime.now().isoformat(timespec="seconds")
        is_file = os.path.isfile(model_path)
        size = os.path.getsize(model_path) if is_file else -1

        log_dir = os.path.dirname(_MODEL_LOAD_LOG_PATH)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        entries = [f"{stamp}\t{model_name}\t{model_path}\texists={is_file}\tsize={size}\n"]
        if ambiguous:
            entries.extend(f"{stamp}\tmatch\t{candidate}\n" for candidate in matches)

        with open(_MODEL_LOAD_LOG_PATH, "a", encoding="utf-8") as handle:
            handle.writelines(entries)
    except Exception as e:
        logging.warning("[Salia Ultralytics] Failed to write model-load log to %s: %s", _MODEL_LOAD_LOG_PATH, e)
|
|
|
|
| def _load_whitelist(filepath: str | None) -> set[str]: |
| if not filepath: |
| return set() |
| try: |
| approved: set[str] = set() |
| with open(filepath, "r", encoding="utf-8") as f: |
| for line in f: |
| line = line.strip() |
| if line and not line.startswith("#"): |
| approved.add(os.path.basename(line)) |
| return approved |
| except Exception: |
| return set() |
|
|
|
|
# Snapshot of whitelisted base filenames taken at import time. `_torch_load_wrapper`
# re-reads the whitelist file whenever a safe load fails, so entries added while
# ComfyUI is running take effect without a restart.
_MODEL_WHITELIST = _load_whitelist(_WHITELIST_PATH)
|
|
|
|
def _torch_load_target_name(args, kwargs):
    """Return the base filename ``torch.load`` was asked to open, or None.

    Accepts ``str`` and ``os.PathLike`` sources such as ``pathlib.Path``. Previously
    only ``str`` was recognised, so Path-based loads could never use the whitelist.
    File-like objects and ``bytes`` paths yield ``None``.
    """
    source = args[0] if args else kwargs.get("f")
    if isinstance(source, os.PathLike):
        source = os.fspath(source)
    if isinstance(source, str):
        return os.path.basename(source)
    return None


def _torch_load_wrapper(*args, **kwargs):
    """Drop-in ``torch.load`` with a whitelist-gated unsafe fallback.

    First calls the original ``torch.load`` with the caller's arguments unchanged.
    If that raises a ``pickle.UnpicklingError`` whose message looks like a
    weights-only restriction, the whitelist file is re-read. If the target file's
    base name is listed, the load is retried with ``weights_only=False``. Any other
    failure, or a non-whitelisted file, re-raises the original error after logging
    where to whitelist it.

    Returns:
        Whatever ``torch.load`` returns.
    Raises:
        pickle.UnpicklingError: when the safe load fails and no fallback applies.
    """
    global _MODEL_WHITELIST

    filename = _torch_load_target_name(args, kwargs)

    try:
        return _ORIG_TORCH_LOAD(*args, **kwargs)
    except pickle.UnpicklingError as e:
        msg = str(e)
        weights_only_markers = (
            "Weights only load failed",
            "Unsupported global",
            "disallowed",
            "not allowed",
            "getattr",
        )
        if not any(marker in msg for marker in weights_only_markers):
            raise

        # Re-read so whitelist edits apply without restarting ComfyUI.
        _MODEL_WHITELIST = _load_whitelist(_WHITELIST_PATH)

        if filename and filename in _MODEL_WHITELIST:
            logging.warning(
                "[Salia Ultralytics] Safe torch.load failed for '%s'. Retrying with weights_only=False because it's whitelisted (%s).",
                filename,
                _WHITELIST_PATH,
            )
            retry_kwargs = dict(kwargs, weights_only=False)
            return _ORIG_TORCH_LOAD(*args, **retry_kwargs)

        logging.error(
            "[Salia Ultralytics] Blocked unsafe model load for '%s'.\n"
            "Safe loading failed and the file is not whitelisted.\n"
            "If you TRUST this model, add its base name to: %s",
            filename or "[unknown]",
            _WHITELIST_PATH or "[whitelist path unavailable]",
        )
        raise
|
|
|
|
@contextmanager
def _patched_torch_load_for_ultralytics():
    """Temporarily route ``torch.load`` through ``_torch_load_wrapper``.

    The patch is only applied when ``torch.serialization`` exposes
    ``safe_globals``. Otherwise the block runs with ``torch.load`` untouched.
    When patched, the previous ``torch.load`` is restored on exit, even if the
    body raises.
    """
    if hasattr(torch.serialization, "safe_globals"):
        saved_loader = torch.load
        torch.load = _torch_load_wrapper
        try:
            yield
        finally:
            torch.load = saved_loader
    else:
        yield
|
|
|
|
def _load_yolo(model_path: str):
    """Load an Ultralytics YOLO model from ``model_path``.

    The construction runs inside ``_patched_torch_load_for_ultralytics`` so the
    whitelist-based ``weights_only=False`` fallback is available.

    Raises:
        ImportError: if ``ultralytics`` cannot be imported (the original error is chained).
    """
    try:
        from ultralytics import YOLO
    except Exception as import_error:
        raise ImportError(
            "[Salia Ultralytics] ultralytics is not installed. Install it in your ComfyUI env, e.g.:\n"
            "pip install ultralytics"
        ) from import_error

    with _patched_torch_load_for_ultralytics():
        model = YOLO(model_path)
    return model
|
|
|
|
| |
| |
| |
|
|
def _tensor2np_rgb(image: torch.Tensor) -> np.ndarray:
    """Convert a ComfyUI IMAGE tensor to a ``uint8`` numpy image.

    For a 4-D batch only the first image is converted. Values are clamped to
    ``[0, 1]``, scaled to ``0..255`` and rounded. A single-channel image is
    repeated to three channels. Other channel counts are passed through.

    The tensor may live on any device (CPU, CUDA, MPS, XPU, ...) and may be half or
    bfloat16 precision.

    Raises:
        TypeError: if ``image`` is not a ``torch.Tensor``.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(image)}")

    img = image[0] if image.dim() == 4 else image

    # `.cpu()` is a no-op for CPU tensors. Checking only `is_cuda` left MPS/XPU
    # tensors on-device, which `.numpy()` refuses to convert.
    img = img.detach().cpu()
    # numpy has no bfloat16, so `.numpy()` fails on it; upcast half precision too.
    if img.dtype in (torch.float16, torch.bfloat16):
        img = img.float()

    arr = img.clamp(0, 1).numpy()
    if arr.shape[-1] == 1:
        arr = np.repeat(arr, 3, axis=-1)

    return (arr * 255.0).round().astype(np.uint8)
|
|
|
|
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert a ComfyUI IMAGE tensor into a PIL image via ``_tensor2np_rgb``."""
    pixels = _tensor2np_rgb(image)
    return Image.fromarray(pixels)
|
|
|
|
| def make_crop_region(w: int, h: int, bbox_xyxy, crop_factor: float, crop_min_size: int | None = None): |
| x1, y1, x2, y2 = [float(v) for v in bbox_xyxy] |
| bbox_w = max(1.0, x2 - x1) |
| bbox_h = max(1.0, y2 - y1) |
|
|
| crop_w = bbox_w * float(crop_factor) |
| crop_h = bbox_h * float(crop_factor) |
|
|
| if crop_min_size is not None: |
| crop_w = max(crop_w, float(crop_min_size)) |
| crop_h = max(crop_h, float(crop_min_size)) |
|
|
| cx = (x1 + x2) / 2.0 |
| cy = (y1 + y2) / 2.0 |
|
|
| rx1 = int(round(cx - crop_w / 2.0)) |
| ry1 = int(round(cy - crop_h / 2.0)) |
| rx2 = int(round(cx + crop_w / 2.0)) |
| ry2 = int(round(cy + crop_h / 2.0)) |
|
|
| rx1 = max(0, min(w - 1, rx1)) |
| ry1 = max(0, min(h - 1, ry1)) |
| rx2 = max(rx1 + 1, min(w, rx2)) |
| ry2 = max(ry1 + 1, min(h, ry2)) |
|
|
| return (rx1, ry1, rx2, ry2) |
|
|
|
|
def crop_image(image: torch.Tensor, crop_region):
    """Slice ``crop_region`` = ``(x1, y1, x2, y2)`` out of an IMAGE tensor.

    Works on batched ``(B, H, W, C)`` and single ``(H, W, C)`` tensors. Coordinates
    are converted with ``int`` and used as half-open slice bounds.

    Raises:
        ValueError: for tensors that are neither 3-D nor 4-D.
    """
    left, top, right, bottom = (int(v) for v in crop_region)
    rank = image.dim()
    if rank == 4:
        return image[:, top:bottom, left:right, :]
    if rank == 3:
        return image[top:bottom, left:right, :]
    raise ValueError(f"Unexpected image tensor shape: {tuple(image.shape)}")
|
|
|
|
def crop_ndarray2(arr: np.ndarray, crop_region):
    """Return the ``(x1, y1, x2, y2)`` window of a 2-D-indexed array.

    Coordinates are truncated with ``int``; the result is a numpy slice (a view).
    """
    left, top, right, bottom = map(int, crop_region)
    return arr[top:bottom, left:right]
|
|
|
|
def dilate_masks(segmasks, dilation: int):
    """Grow each ``(bbox, mask, conf)`` mask with an elliptical kernel.

    Masks are binarised at ``> 0.5``, dilated once with a
    ``(2*dilation+1)``-sized ellipse and returned as ``float32`` 0/1 arrays. A
    ``dilation <= 0`` returns ``segmasks`` itself unchanged.

    Raises:
        ImportError: when dilation is requested but OpenCV is unavailable.
    """
    if dilation <= 0:
        return segmasks
    if cv2 is None:
        raise ImportError(
            "[Salia Ultralytics] opencv-python is required for mask dilation but cv2 could not be imported.\n"
            "Install: pip install opencv-python-headless"
        )

    radius = int(dilation)
    diameter = radius * 2 + 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (diameter, diameter))

    def grow(mask):
        binary = (mask > 0.5).astype(np.uint8) * 255
        grown = cv2.dilate(binary, kernel, iterations=1)
        return (grown > 0).astype(np.float32)

    return [(bbox, grow(mask), conf) for bbox, mask, conf in segmasks]
|
|
|
|
| def combine_masks(segmasks, out_shape_hw: tuple[int, int] | None = None) -> torch.Tensor: |
| if not segmasks: |
| if out_shape_hw is None: |
| return torch.zeros((1, 1, 1), dtype=torch.float32) |
| h, w = out_shape_hw |
| return torch.zeros((1, h, w), dtype=torch.float32) |
|
|
| base = segmasks[0][1] |
| combined = np.zeros_like(base, dtype=np.float32) |
| for _, m, _ in segmasks: |
| combined = np.maximum(combined, m.astype(np.float32)) |
| return torch.from_numpy(combined).unsqueeze(0) |
|
|
|
|
| |
| |
| |
|
|
# Impact-compatible SEG record. Only `control_net_wrapper` has a default (None);
# every other field must be supplied positionally or by name.
SEG = namedtuple(
    "SEG",
    "cropped_image cropped_mask confidence crop_region bbox label control_net_wrapper",
    defaults=(None,),
)
|
|
|
|
class NO_BBOX_DETECTOR:
    """Empty marker type standing in for a missing BBOX detector.

    NOTE(review): not referenced elsewhere in this file; presumably kept for
    parity with Impact-Pack's placeholder types.
    """
|
|
|
|
class NO_SEGM_DETECTOR:
    """Empty marker used as the SEGM_DETECTOR output for ``bbox/`` models.

    It deliberately defines no detection methods.
    """
|
|
|
|
def _create_segmasks(results):
    """Turn ``[labels, bboxes, masks, confidences]`` into ``(bbox, mask, conf)`` tuples.

    Labels (``results[0]``) are not used here. Each mask is converted to ``float32``.
    One tuple is produced per mask.
    """
    bboxes, masks, confidences = results[1], results[2], results[3]
    return [
        (bboxes[idx], mask.astype(np.float32), confidences[idx])
        for idx, mask in enumerate(masks)
    ]
|
|
|
|
def _inference_bbox(model, image_pil: Image.Image, confidence: float = 0.3, device: str = ""):
    """Run a YOLO model and return Impact-style results with rectangular masks.

    Returns:
        ``[labels, bboxes, masks, confidences]``, or four empty lists when nothing
        was detected.
        - ``labels``: class names from ``pred[0].names``.
        - ``bboxes``: ``xyxy`` numpy rows.
        - ``masks``: boolean ``(H, W)`` arrays with the bbox rectangle filled,
          inclusive of the far corner as ``cv2.rectangle(..., thickness=-1)`` draws it.
        - ``confidences``: per-box numpy arrays from ``boxes[i].conf``.
    """
    pred = model(image_pil, conf=confidence, device=device)
    result = pred[0]

    bboxes = result.boxes.xyxy.cpu().numpy()
    if bboxes.shape[0] == 0:
        return [[], [], [], []]

    # PIL reports (width, height); no need to copy all pixels just to get the size.
    w, h = image_pil.size

    segms = []
    for x0, y0, x1, y1 in bboxes:
        x0i = max(0, min(w - 1, int(x0)))
        y0i = max(0, min(h - 1, int(y0)))
        x1i = max(0, min(w, int(x1)))
        y1i = max(0, min(h, int(y1)))
        mask = np.zeros((h, w), dtype=bool)
        # +1: filled cv2 rectangles include the (x1, y1) corner. The old numpy
        # fallback excluded it, so masks differed depending on whether cv2 was installed.
        mask[y0i:y1i + 1, x0i:x1i + 1] = True
        segms.append(mask)

    labels = []
    confs = []
    for i in range(len(bboxes)):
        box = result.boxes[i]
        labels.append(result.names[int(box.cls.item())])
        confs.append(box.conf.detach().cpu().numpy())

    return [labels, list(bboxes), segms, confs]
|
|
|
|
def _letterbox_content_box(mask_h: int, mask_w: int, orig_h: int, orig_w: int):
    """Locate the real image inside a letterboxed ``(mask_h, mask_w)`` mask.

    Assumes the image was resized by a single gain and padded symmetrically.
    This is how Ultralytics letterboxes inputs, mirroring ``ultralytics.utils.ops.scale_image``.

    Returns:
        ``(top, bottom, left, right)`` slice bounds. The full mask is returned when
        the sizes already match or the geometry is degenerate.
    """
    full = (0, mask_h, 0, mask_w)
    if orig_h <= 0 or orig_w <= 0 or (mask_h, mask_w) == (orig_h, orig_w):
        return full

    gain = min(mask_h / orig_h, mask_w / orig_w)
    pad_h = (mask_h - orig_h * gain) / 2.0
    pad_w = (mask_w - orig_w * gain) / 2.0

    top, left = int(pad_h), int(pad_w)
    bottom, right = int(mask_h - pad_h), int(mask_w - pad_w)
    if bottom <= top or right <= left:
        return full
    return top, bottom, left, right


def _inference_segm(model, image_pil: Image.Image, confidence: float = 0.3, device: str = ""):
    """Run a YOLO segmentation model and return Impact-style results.

    Returns:
        ``[labels, bboxes, masks, confidences]``, or four empty lists when nothing
        was detected. Masks are float arrays at the original image size ``(H, W)``.
        If the model produced no masks (a detection-only model), this falls back to
        ``_inference_bbox``, which runs the model a second time.

    NOTE: Ultralytics returns ``masks.data`` at the letterboxed network input size,
    unless retina masks are enabled; confirm for the installed version. The padding
    is therefore cropped away before resizing. Stretching the padded mask over the
    whole image misaligns it whenever letterbox padding was added.
    """
    pred = model(image_pil, conf=confidence, device=device)
    result = pred[0]

    bboxes = result.boxes.xyxy.cpu().numpy()
    if bboxes.shape[0] == 0:
        return [[], [], [], []]

    if result.masks is None or result.masks.data is None:
        return _inference_bbox(model, image_pil, confidence=confidence, device=device)

    w_orig, h_orig = image_pil.size

    masks = result.masks.data.detach().cpu().float()  # (N, mask_h, mask_w)
    top, bottom, left, right = _letterbox_content_box(masks.shape[-2], masks.shape[-1], h_orig, w_orig)
    masks = masks[:, top:bottom, left:right]
    # Resize all masks in one call: (N, 1, h, w) -> (N, 1, H, W).
    masks = F.interpolate(masks.unsqueeze(1), size=(h_orig, w_orig), mode="bilinear", align_corners=False)
    masks_np = masks.squeeze(1).numpy()

    results = [[], [], [], []]
    for i in range(len(bboxes)):
        box = result.boxes[i]
        results[0].append(result.names[int(box.cls.item())])
        results[1].append(bboxes[i])
        results[2].append(masks_np[i])
        results[3].append(box.conf.detach().cpu().numpy())
    return results
|
|
|
|
class SaliaUltraBBoxDetector:
    """Impact-compatible ``BBOX_DETECTOR`` backed by an Ultralytics YOLO model.

    Masks come from ``_inference_bbox``, i.e. filled bounding-box rectangles.
    """

    def __init__(self, model):
        self.model = model

    def _detect_masks(self, image, threshold, dilation):
        """Run the model on ``image`` (via ``tensor2pil``) and return ``(labels, segmasks)``.

        ``segmasks`` is a list of ``(bbox, mask, confidence)``; masks are dilated
        when ``int(dilation) > 0``.
        """
        detected = _inference_bbox(self.model, tensor2pil(image), confidence=float(threshold))
        segmasks = _create_segmasks(detected)
        radius = int(dilation)
        if radius > 0:
            segmasks = dilate_masks(segmasks, radius)
        return detected[0], segmasks

    def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """Detect objects and return Impact-style SEGS ``((H, W), [SEG, ...])``.

        Args:
            image: ComfyUI IMAGE tensor; ``image.shape[1]`` / ``image.shape[2]`` are H / W.
            threshold: minimum confidence passed to the model.
            dilation: mask dilation radius in pixels (ignored when <= 0).
            crop_factor: crop size relative to each bbox (see ``make_crop_region``).
            drop_size: detections whose bbox width or height is not greater than this
                (clamped to at least 1) are discarded.
            detailer_hook: optional object. If it has ``post_crop_region``, that may
                rewrite each crop region; if it has ``post_detection``, that may
                rewrite the final SEGS.
        """
        min_extent = max(int(drop_size), 1)
        labels, segmasks = self._detect_masks(image, threshold, dilation)

        height, width = image.shape[1], image.shape[2]
        items = []
        for (bbox, mask, conf), label in zip(segmasks, labels):
            x1, y1, x2, y2 = bbox
            if not ((x2 - x1) > min_extent and (y2 - y1) > min_extent):
                continue

            crop_region = make_crop_region(width, height, bbox, float(crop_factor))
            if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"):
                crop_region = detailer_hook.post_crop_region(width, height, bbox, crop_region)

            cropped_image = crop_image(image, crop_region)
            cropped_mask = crop_ndarray2(mask, crop_region)
            items.append(SEG(cropped_image, cropped_mask, conf, crop_region, bbox, label, None))

        segs = (height, width), items
        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)
        return segs

    def detect_combined(self, image, threshold, dilation):
        """Return the union of all detection masks as a ``(1, H, W)`` float tensor."""
        _labels, segmasks = self._detect_masks(image, threshold, dilation)
        return combine_masks(segmasks, out_shape_hw=(image.shape[1], image.shape[2]))

    def setAux(self, x):
        """Accept and ignore an auxiliary value so callers invoking ``setAux`` do not fail."""
        return None
|
|
|
|
class SaliaUltraSegmDetector:
    """Impact-compatible ``SEGM_DETECTOR`` backed by an Ultralytics YOLO model.

    Masks come from ``_inference_segm`` (segmentation masks at image resolution).
    """

    def __init__(self, model):
        self.model = model

    def _detect_masks(self, image, threshold, dilation):
        """Run the model on ``image`` (via ``tensor2pil``) and return ``(labels, segmasks)``.

        ``segmasks`` is a list of ``(bbox, mask, confidence)``; masks are dilated
        when ``int(dilation) > 0``.
        """
        detected = _inference_segm(self.model, tensor2pil(image), confidence=float(threshold))
        segmasks = _create_segmasks(detected)
        radius = int(dilation)
        if radius > 0:
            segmasks = dilate_masks(segmasks, radius)
        return detected[0], segmasks

    def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """Detect objects and return Impact-style SEGS ``((H, W), [SEG, ...])``.

        Args:
            image: ComfyUI IMAGE tensor; ``image.shape[1]`` / ``image.shape[2]`` are H / W.
            threshold: minimum confidence passed to the model.
            dilation: mask dilation radius in pixels (ignored when <= 0).
            crop_factor: crop size relative to each bbox (see ``make_crop_region``).
            drop_size: detections whose bbox width or height is not greater than this
                (clamped to at least 1) are discarded.
            detailer_hook: optional object. If it has ``post_crop_region``, that may
                rewrite each crop region; if it has ``post_detection``, that may
                rewrite the final SEGS.
        """
        min_extent = max(int(drop_size), 1)
        labels, segmasks = self._detect_masks(image, threshold, dilation)

        height, width = image.shape[1], image.shape[2]
        items = []
        for (bbox, mask, conf), label in zip(segmasks, labels):
            x1, y1, x2, y2 = bbox
            if not ((x2 - x1) > min_extent and (y2 - y1) > min_extent):
                continue

            crop_region = make_crop_region(width, height, bbox, float(crop_factor))
            if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"):
                crop_region = detailer_hook.post_crop_region(width, height, bbox, crop_region)

            cropped_image = crop_image(image, crop_region)
            cropped_mask = crop_ndarray2(mask, crop_region)
            items.append(SEG(cropped_image, cropped_mask, conf, crop_region, bbox, label, None))

        segs = (height, width), items
        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)
        return segs

    def detect_combined(self, image, threshold, dilation):
        """Return the union of all detection masks as a ``(1, H, W)`` float tensor."""
        _labels, segmasks = self._detect_masks(image, threshold, dilation)
        return combine_masks(segmasks, out_shape_hw=(image.shape[1], image.shape[2]))

    def setAux(self, x):
        """Accept and ignore an auxiliary value so callers invoking ``setAux`` do not fail."""
        return None
|
|
|
|
| |
| |
| |
|
|
class SaliaUltralyticsDetectorProvider2:
    """Load an Ultralytics `.pt` model and provide Impact-compatible detectors.

    ``bbox/`` models yield ``(SaliaUltraBBoxDetector, NO_SEGM_DETECTOR)``. Any other
    (``segm/``) model yields ``(SaliaUltraBBoxDetector, SaliaUltraSegmDetector)``
    sharing one loaded model.
    """

    RETURN_TYPES = ("BBOX_DETECTOR", "SEGM_DETECTOR")
    FUNCTION = "doit"
    CATEGORY = "Salia/Detectors"

    @classmethod
    def INPUT_TYPES(cls):
        options = ["bbox/" + name for name in folder_paths.get_filename_list("ultralytics_bbox")]
        options += ["segm/" + name for name in folder_paths.get_filename_list("ultralytics_segm")]
        return {"required": {"model_name": (options,)}}

    @staticmethod
    def _resolve_model_path(model_name):
        """Return the on-disk path for ``model_name`` or ``None``.

        Tries the ``ultralytics`` key with the full name first, then the
        ``ultralytics_bbox`` / ``ultralytics_segm`` key with the prefix removed.
        """
        path = folder_paths.get_full_path("ultralytics", model_name)
        if path is not None:
            return path
        if model_name.startswith("bbox/"):
            return folder_paths.get_full_path("ultralytics_bbox", model_name[5:])
        if model_name.startswith("segm/"):
            return folder_paths.get_full_path("ultralytics_segm", model_name[5:])
        return None

    @staticmethod
    def _searched_folders(model_name):
        """Best-effort list of the folders that were searched, for the error message."""
        folders = []
        try:
            folders.extend(folder_paths.get_folder_paths("ultralytics"))
            if model_name.startswith("bbox/"):
                folders.extend(folder_paths.get_folder_paths("ultralytics_bbox"))
            elif model_name.startswith("segm/"):
                folders.extend(folder_paths.get_folder_paths("ultralytics_segm"))
        except Exception:
            pass
        return folders

    def doit(self, model_name: str):
        model_path = self._resolve_model_path(model_name)
        if model_path is None:
            formatted = "\n\t".join(self._searched_folders(model_name))
            raise ValueError(
                f"[Salia Ultralytics] model file '{model_name}' was not found.\n"
                f"Searched these folders:\n\t{formatted}\n"
                f"Tip: put bbox models in 'models/ultralytics/bbox' or segm models in 'models/ultralytics/segm'."
            )

        # Record which file is actually used (and any same-named duplicates).
        matches = _find_all_model_paths(model_name)
        _log_selected_model(model_name, os.path.abspath(model_path), matches)

        model = _load_yolo(model_path)
        bbox_detector = SaliaUltraBBoxDetector(model)
        if model_name.startswith("bbox/"):
            return bbox_detector, NO_SEGM_DETECTOR()
        return bbox_detector, SaliaUltraSegmDetector(model)
|
|
|
|
# ComfyUI discovers this node through the two module-level mappings below; both
# are keyed by the same node identifier.
_SALIA_NODE_KEY = "SaliaUltralyticsDetectorProvider2"

NODE_CLASS_MAPPINGS = {_SALIA_NODE_KEY: SaliaUltralyticsDetectorProvider2}

NODE_DISPLAY_NAME_MAPPINGS = {_SALIA_NODE_KEY: "Salia Ultralytics Detector 2 (Salia)"}
|
|