| """
|
| Wrapper class to call the stablediffusion.cpp shared library for GGUF support
|
| """
|
|
|
| import ctypes
|
| import platform
|
| from ctypes import (
|
| POINTER,
|
| c_bool,
|
| c_char_p,
|
| c_float,
|
| c_int,
|
| c_int64,
|
| c_void_p,
|
| )
|
| from dataclasses import dataclass
|
| from os import path
|
| from typing import List, Any
|
|
|
| import numpy as np
|
| from PIL import Image
|
|
|
| from backend.gguf.sdcpp_types import (
|
| RngType,
|
| SampleMethod,
|
| Schedule,
|
| SDCPPLogLevel,
|
| SDImage,
|
| SdType,
|
| )
|
|
|
|
|
@dataclass
class ModelConfig:
    """Model-loading settings consumed by ``GGUFDiffusion``.

    The fields are forwarded, in declaration order, as the positional
    arguments of the native ``new_sd_ctx`` call. ``GGUFDiffusion`` refuses
    to start unless ``clip_l_path``, ``t5xxl_path``,
    ``diffusion_model_path`` and ``vae_path`` point to existing files.
    """

    # --- file and directory locations (all empty by default) ---
    model_path: str = ""
    clip_l_path: str = ""
    t5xxl_path: str = ""
    diffusion_model_path: str = ""
    vae_path: str = ""
    taesd_path: str = ""
    control_net_path: str = ""
    lora_model_dir: str = ""
    embed_dir: str = ""
    stacked_id_embed_dir: str = ""

    # --- VAE / memory behaviour flags ---
    vae_decode_only: bool = True
    vae_tiling: bool = False
    free_params_immediately: bool = False

    # --- execution settings ---
    n_threads: int = 4
    wtype: SdType = SdType.SD_TYPE_Q4_0
    rng_type: RngType = RngType.CUDA_RNG
    schedule: Schedule = Schedule.DEFAULT

    # --- per-component CPU placement flags ---
    keep_clip_on_cpu: bool = False
    keep_control_net_cpu: bool = False
    keep_vae_on_cpu: bool = False
|
|
|
|
|
@dataclass
class Txt2ImgConfig:
    """Per-request settings for ``GGUFDiffusion.generate_text2mg``.

    The fields are forwarded, in declaration order, as the arguments that
    follow the context handle in the native ``txt2img`` call.
    """

    # --- prompts ---
    prompt: str = "a man wearing sun glasses, highly detailed"
    negative_prompt: str = ""

    # --- sampling parameters ---
    clip_skip: int = -1
    cfg_scale: float = 2.0
    guidance: float = 3.5
    width: int = 512
    height: int = 512
    sample_method: SampleMethod = SampleMethod.EULER_A
    sample_steps: int = 1
    seed: int = -1
    # Also used as the number of entries read back from the result buffer.
    batch_count: int = 2

    # --- optional conditioning inputs ---
    # Passed to a ``POINTER(SDImage)`` parameter, so this must be None or a
    # ctypes pointer to an SDImage. It was previously annotated with the
    # PIL ``Image`` module, which is not a type.
    control_cond: Any = None
    control_strength: float = 0.90
    style_strength: float = 0.5
    normalize_input: bool = False
    # Already encoded: handed to a ``c_char_p`` parameter without conversion.
    input_id_images_path: bytes = b""
|
|
|
|
|
class GGUFDiffusion:
    """GGUF Diffusion.

    Runs GGUF diffusion models through the stablediffusion.cpp shared
    library, which is loaded with ctypes.
    GGUF format: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
    Implemented based on stablediffusion.h
    """

    def __init__(
        self,
        libpath: str,
        config: "ModelConfig",
        logging_enabled: bool = False,
    ):
        """Load the shared library and create the native diffusion context.

        Args:
            libpath: Directory that contains the stablediffusion.cpp shared
                library for the current platform.
            config: Model settings; the CLIP, T5XXL, diffusion model and VAE
                paths must point to existing files.
            logging_enabled: When true, route the library's log messages to
                ``log_callback``.

        Raises:
            ValueError: If the platform is not supported, the library cannot
                be loaded, or a required model file is missing.
        """
        # Assigned before anything can fail so terminate() stays safe to
        # call on a partially constructed instance.
        self.libsdcpp = None
        self.sd_ctx = None

        sdcpp_shared_lib_path = self._get_sdcpp_shared_lib_path(libpath)
        if not sdcpp_shared_lib_path:
            # An empty path must not reach ctypes.CDLL.
            raise ValueError(
                f"Error: unsupported platform {platform.system()}"
            )

        try:
            self.libsdcpp = ctypes.CDLL(sdcpp_shared_lib_path)
        except OSError as e:
            print(f"Failed to load library {sdcpp_shared_lib_path}")
            # Chain the loader error so the root cause stays in the traceback.
            raise ValueError(f"Error: {e}") from e

        required_files = (
            ("CLIP", config.clip_l_path),
            ("T5XXL", config.t5xxl_path),
            ("Diffusion", config.diffusion_model_path),
            ("VAE", config.vae_path),
        )
        for label, file_path in required_files:
            if not file_path or not path.exists(file_path):
                raise ValueError(
                    f"{label} model file not found,please check readme.md for GGUF model usage"
                )

        self.model_config = config

        # Positional signature of new_sd_ctx: ten path strings, three VAE /
        # memory flags, thread count, three enums, three CPU-placement flags.
        self.libsdcpp.new_sd_ctx.argtypes = [
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            c_bool,
            c_bool,
            c_bool,
            c_int,
            SdType,
            RngType,
            Schedule,
            c_bool,
            c_bool,
            c_bool,
        ]
        self.libsdcpp.new_sd_ctx.restype = POINTER(c_void_p)

        cfg = self.model_config
        self.sd_ctx = self.libsdcpp.new_sd_ctx(
            self._str_to_bytes(cfg.model_path),
            self._str_to_bytes(cfg.clip_l_path),
            self._str_to_bytes(cfg.t5xxl_path),
            self._str_to_bytes(cfg.diffusion_model_path),
            self._str_to_bytes(cfg.vae_path),
            self._str_to_bytes(cfg.taesd_path),
            self._str_to_bytes(cfg.control_net_path),
            self._str_to_bytes(cfg.lora_model_dir),
            self._str_to_bytes(cfg.embed_dir),
            self._str_to_bytes(cfg.stacked_id_embed_dir),
            cfg.vae_decode_only,
            cfg.vae_tiling,
            cfg.free_params_immediately,
            cfg.n_threads,
            cfg.wtype,
            cfg.rng_type,
            cfg.schedule,
            cfg.keep_clip_on_cpu,
            cfg.keep_control_net_cpu,
            cfg.keep_vae_on_cpu,
        )

        # NOTE(review): the callback is registered after the context has
        # been created, so messages emitted during new_sd_ctx are not routed
        # through log_callback - confirm this ordering is intended.
        if logging_enabled:
            self._set_logcallback()

    def _set_logcallback(self):
        """Register ``log_callback`` as the library's log sink."""
        print("Setting logging callback")

        SdLogCallbackType = ctypes.CFUNCTYPE(
            None,
            SDCPPLogLevel,
            ctypes.c_char_p,
            ctypes.c_void_p,
        )

        self.libsdcpp.sd_set_log_callback.argtypes = [
            SdLogCallbackType,
            ctypes.c_void_p,
        ]
        self.libsdcpp.sd_set_log_callback.restype = None

        # Stored on the instance so the ctypes callback object is not
        # garbage-collected while native code may still call it.
        self.c_log_callback = SdLogCallbackType(self.log_callback)
        self.libsdcpp.sd_set_log_callback(self.c_log_callback, None)

    def _get_sdcpp_shared_lib_path(
        self,
        root_path: str,
    ) -> str:
        """Return the platform-specific library path under ``root_path``.

        Returns an empty string (after printing a notice) when the platform
        is not Windows, Linux or Darwin.
        """
        system_name = platform.system()
        print(f"GGUF Diffusion on {system_name}")

        lib_names = {
            "Windows": "stable-diffusion.dll",
            "Linux": "libstable-diffusion.so",
            "Darwin": "libstable-diffusion.dylib",
        }
        lib_name = lib_names.get(system_name)
        if lib_name is None:
            print("Unknown platform.")
            return ""

        return path.join(root_path, lib_name)

    @staticmethod
    def log_callback(
        level,
        text,
        data,
    ):
        """Print one log message received from the native library.

        ``level`` and ``data`` are accepted to match the callback signature
        but are not used. A missing (``None``) message is ignored and
        undecodable bytes are replaced, so an exception never escapes into
        the native caller.
        """
        if text is None:
            return
        print(text.decode("utf-8", errors="replace"), end="")

    def _str_to_bytes(self, in_str: str, encoding: str = "utf-8") -> bytes:
        """Encode ``in_str`` for a ``c_char_p`` argument; falsy input gives ``b""``."""
        return in_str.encode(encoding) if in_str else b""

    def generate_text2mg(self, txt2img_cfg: "Txt2ImgConfig") -> List[Any]:
        """Run the native ``txt2img`` call and return the generated images.

        Args:
            txt2img_cfg: Prompt and sampling settings for this request.

        Returns:
            A list with ``txt2img_cfg.batch_count`` PIL images, or an empty
            list when the library returns a null buffer.
        """
        self.libsdcpp.txt2img.restype = POINTER(SDImage)
        self.libsdcpp.txt2img.argtypes = [
            c_void_p,
            c_char_p,
            c_char_p,
            c_int,
            c_float,
            c_float,
            c_int,
            c_int,
            SampleMethod,
            c_int,
            c_int64,
            c_int,
            POINTER(SDImage),
            c_float,
            c_float,
            c_bool,
            c_char_p,
        ]

        image_buffer = self.libsdcpp.txt2img(
            self.sd_ctx,
            self._str_to_bytes(txt2img_cfg.prompt),
            self._str_to_bytes(txt2img_cfg.negative_prompt),
            txt2img_cfg.clip_skip,
            txt2img_cfg.cfg_scale,
            txt2img_cfg.guidance,
            txt2img_cfg.width,
            txt2img_cfg.height,
            txt2img_cfg.sample_method,
            txt2img_cfg.sample_steps,
            txt2img_cfg.seed,
            txt2img_cfg.batch_count,
            txt2img_cfg.control_cond,
            txt2img_cfg.control_strength,
            txt2img_cfg.style_strength,
            txt2img_cfg.normalize_input,
            txt2img_cfg.input_id_images_path,
        )

        return self._get_sd_images_from_buffer(
            image_buffer,
            txt2img_cfg.batch_count,
        )

    def _get_sd_images_from_buffer(
        self,
        image_buffer: Any,
        batch_count: int,
    ) -> List[Any]:
        """Convert the first ``batch_count`` entries of ``image_buffer`` to PIL images.

        Returns an empty list when ``image_buffer`` is falsy (for example a
        null pointer).

        Raises:
            ValueError: If an image has a channel count other than 1, 3 or 4.
        """
        images = []
        if not image_buffer:
            return images

        # NOTE(review): nothing here releases the native buffers returned by
        # txt2img - confirm who owns that memory.
        pil_modes = {1: "L", 3: "RGB", 4: "RGBA"}
        for i in range(batch_count):
            image = image_buffer[i]
            width = image.width
            height = image.height
            channels = image.channel
            print(
                f"Generated image: {width}x{height} with {channels} channels"
            )

            pixel_data = np.ctypeslib.as_array(
                image.data, shape=(height, width, channels)
            )

            mode = pil_modes.get(channels)
            if mode is None:
                raise ValueError(f"Unsupported number of channels: {channels}")
            if channels == 1:
                # Drop the trailing single-channel axis for grayscale data.
                pixel_data = pixel_data.squeeze()

            images.append(Image.fromarray(pixel_data, mode=mode))
        return images

    def terminate(self):
        """Free the native context and drop the library handle.

        Safe to call more than once, and on an instance whose construction
        failed part-way through.
        """
        lib = getattr(self, "libsdcpp", None)
        ctx = getattr(self, "sd_ctx", None)
        if lib is not None and ctx:
            lib.free_sd_ctx.argtypes = [c_void_p]
            lib.free_sd_ctx.restype = None
            lib.free_sd_ctx(ctx)
        self.sd_ctx = None
        self.libsdcpp = None
|
|
|