| from __future__ import annotations |
|
|
| import importlib |
| import re |
| from functools import lru_cache |
| from pathlib import Path |
|
|
| from modules import extensions, sd_models, shared |
| from modules.paths import data_path, models_path, script_path |
|
|
# Extension directories derived from the web UI's data and script roots.
ext_path = Path(data_path, "extensions")
ext_builtin_path = Path(script_path, "extensions-builtin")

# Defaults describe the "ControlNet not found" state; the scan below
# overwrites them when an enabled ControlNet extension is active.
controlnet_exists = False
controlnet_path = None
cn_base_path = ""

for extension in extensions.active():
    # Substring match, so directory names that merely contain
    # "sd-webui-controlnet" (e.g. with a suffix) are accepted as well.
    if extension.enabled and "sd-webui-controlnet" in extension.name:
        controlnet_exists = True
        controlnet_path = Path(extension.path)
        # Dotted prefix built from the last two path components; it is later
        # combined with ".scripts.external_code" to form an import path.
        cn_base_path = ".".join(controlnet_path.parts[-2:])
        break  # the first enabled match wins
|
|
# Keyword contained in a ControlNet model name -> `module` value used for that
# model when the caller does not pick one. "tile" maps to None, which is
# passed through unchanged.
cn_model_module = {
    "inpaint": "inpaint_global_harmonious",
    "scribble": "t2ia_sketch_pidi",
    "lineart": "lineart_coarse",
    "openpose": "openpose_full",
    "tile": None,
}

# Alternation of every keyword above; a model name is considered supported
# when this pattern matches anywhere inside it.
cn_model_regex = re.compile("|".join(cn_model_module))
|
|
|
|
class ControlNetExt:
    """Wrapper around the ControlNet extension's ``external_code`` module.

    A new instance starts in the "unavailable" state. ``init_controlnet``
    must complete before ``update_scripts_args`` has any effect.
    """

    def __init__(self):
        # Selectable model names; the literal "None" entry is always first.
        self.cn_models = ["None"]
        # Set to True by ``init_controlnet`` once the import has succeeded.
        self.cn_available = False
        # The imported ``external_code`` module, or None before initialisation.
        self.external_cn = None

    def init_controlnet(self):
        """Import ``external_code`` and record the supported model names.

        Model names returned by ``external_cn.get_models()`` are appended to
        ``self.cn_models`` when ``cn_model_regex`` matches them. Exceptions
        raised by the import are not caught here.
        """
        module_name = f"{cn_base_path}.scripts.external_code"

        # The second argument is only consulted for relative module names,
        # i.e. when ``cn_base_path`` is empty.
        self.external_cn = importlib.import_module(module_name, "external_code")
        self.cn_available = True

        supported = [
            name
            for name in self.external_cn.get_models()
            if cn_model_regex.search(name)
        ]
        self.cn_models.extend(supported)

    def update_scripts_args(
        self,
        p,
        model: str,
        module: str | None,
        weight: float,
        guidance_start: float,
        guidance_end: float,
    ):
        """Register a single ControlNet unit on the processing object ``p``.

        Returns without doing anything when ControlNet is unavailable or when
        ``model`` is the string "None". If ``module`` is None it is taken from
        ``cn_model_module``, using the first key that occurs inside ``model``;
        it stays None when no key matches.
        """
        if model == "None" or not self.cn_available:
            return

        if module is None:
            module = next(
                (value for key, value in cn_model_module.items() if key in model),
                None,
            )

        api = self.external_cn
        unit = api.ControlNetUnit(
            model=model,
            weight=weight,
            control_mode=api.ControlMode.BALANCED,
            module=module,
            guidance_start=guidance_start,
            guidance_end=guidance_end,
            pixel_perfect=True,
        )
        api.update_cn_script_in_processing(p, [unit])
|
|
|
|
def get_cn_model_dirs() -> list[Path]:
    """Return the directories that may hold ControlNet model files.

    ``<models_path>/ControlNet`` always comes first. It is followed, in this
    order and only when set, by the ``models`` folder inside the detected
    extension directory, the ``control_net_models_path`` option and
    ``shared.cmd_opts.controlnet_dir``. Existence of the directories is not
    checked here.
    """
    legacy_dir = (
        None if controlnet_path is None else controlnet_path.joinpath("models")
    )
    option_dir = shared.opts.data.get("control_net_models_path", "")
    cmdline_dir = getattr(shared.cmd_opts, "controlnet_dir", "")

    # Falsy candidates (None or an empty string) are skipped.
    optional_dirs = (legacy_dir, option_dir, cmdline_dir)
    return [
        Path(models_path, "ControlNet"),
        *(Path(candidate) for candidate in optional_dirs if candidate),
    ]
|
|
|
|
@lru_cache
def _get_cn_models() -> list[str]:
    """
    List ControlNet models found on disk without importing ControlNet.

    This approximates controlnet's
    `list(global_state.cn_models_names.values())`: every directory from
    `get_cn_model_dirs()` is searched recursively for files that have a known
    model suffix, whose name matches `cn_model_regex`, and whose name contains
    the `control_net_models_name_filter` option when that option is set.

    Entries are formatted as "<stem> [<hash>]" and ordered by file name.
    The result is memoised by `lru_cache`, so later calls return the cached
    list instead of scanning again.
    """
    allowed_suffixes = (".pt", ".pth", ".ckpt", ".safetensors")
    search_dirs = get_cn_model_dirs()
    raw_filter = shared.opts.data.get("control_net_models_name_filter", "")
    needle = raw_filter.strip(" ").lower()

    def is_wanted(path: Path) -> bool:
        # Checks run in the same order as a single short-circuiting condition.
        if not path.is_file():
            return False
        if path.suffix not in allowed_suffixes:
            return False
        if not cn_model_regex.search(path.name):
            return False
        # An empty filter accepts everything; otherwise the comparison is
        # a case-insensitive substring test.
        return not needle or needle in path.name.lower()

    found = [
        path
        for root in search_dirs
        if root.exists()
        for path in root.rglob("*")
        if is_wanted(path)
    ]
    found.sort(key=lambda path: path.name)

    return [f"{path.stem} [{sd_models.model_hash(path)}]" for path in found]
|
|
|
|
def get_cn_models() -> list[str]:
    """Return the on-disk ControlNet model names.

    An empty list is returned when no enabled ControlNet extension was
    detected at import time (``controlnet_exists`` is False).
    """
    return _get_cn_models() if controlnet_exists else []
|
|