| from __future__ import annotations |
|
|
| from contextlib import contextmanager |
|
|
| from modules import img2img, processing, shared |
|
|
|
|
| def cn_restore_unet_hook(p, cn_latest_network): |
| if cn_latest_network is not None: |
| unet = p.sd_model.model.diffusion_model |
| cn_latest_network.restore(unet) |
|
|
|
|
| class CNHijackRestore: |
| def __init__(self): |
| self.process = hasattr(processing, "__controlnet_original_process_images_inner") |
| self.img2img = hasattr(img2img, "__controlnet_original_process_batch") |
|
|
| def __enter__(self): |
| if self.process: |
| self.orig_process = processing.process_images_inner |
| processing.process_images_inner = getattr( |
| processing, "__controlnet_original_process_images_inner" |
| ) |
| if self.img2img: |
| self.orig_img2img = img2img.process_batch |
| img2img.process_batch = getattr( |
| img2img, "__controlnet_original_process_batch" |
| ) |
|
|
| def __exit__(self, *args, **kwargs): |
| if self.process: |
| processing.process_images_inner = self.orig_process |
| if self.img2img: |
| img2img.process_batch = self.orig_img2img |
|
|
|
|
| @contextmanager |
| def cn_allow_script_control(): |
| orig = False |
| if "control_net_allow_script_control" in shared.opts.data: |
| try: |
| orig = shared.opts.data["control_net_allow_script_control"] |
| shared.opts.data["control_net_allow_script_control"] = True |
| yield |
| finally: |
| shared.opts.data["control_net_allow_script_control"] = orig |
| else: |
| yield |
|
|