| """ |
Wrapper class for calling the stable-diffusion.cpp shared library to support GGUF models
| """ |
|
|
import ctypes
import platform
from ctypes import (
    POINTER,
    c_bool,
    c_char_p,
    c_float,
    c_int,
    c_int64,
    c_void_p,
)
from dataclasses import dataclass
from os import path
from typing import Any, List, Optional

import numpy as np
from PIL import Image

from backend.gguf.sdcpp_types import (
    RngType,
    SampleMethod,
    Schedule,
    SDCPPLogLevel,
    SDImage,
    SdType,
)
|
|
|
|
@dataclass
class ModelConfig:
    """Model file locations and runtime options for creating a native context.

    ``GGUFDiffusion`` forwards every field, in declaration order, as a
    positional argument to ``new_sd_ctx``. Path fields are encoded to bytes
    first, and an empty string becomes ``b""``.
    """

    # --- Model / resource paths -------------------------------------------
    # GGUFDiffusion requires clip_l_path, t5xxl_path, diffusion_model_path
    # and vae_path to point at existing files; model_path is not checked.
    model_path: str = ""
    clip_l_path: str = ""
    t5xxl_path: str = ""
    diffusion_model_path: str = ""
    vae_path: str = ""
    taesd_path: str = ""
    control_net_path: str = ""
    lora_model_dir: str = ""
    embed_dir: str = ""
    stacked_id_embed_dir: str = ""

    # --- VAE / memory behaviour ---------------------------------------------
    vae_decode_only: bool = True
    vae_tiling: bool = False
    free_params_immediately: bool = False

    # --- Compute settings ----------------------------------------------------
    n_threads: int = 4
    wtype: SdType = SdType.SD_TYPE_Q4_0
    rng_type: RngType = RngType.CUDA_RNG
    schedule: Schedule = Schedule.DEFAULT

    # --- Device placement ----------------------------------------------------
    keep_clip_on_cpu: bool = False
    keep_control_net_cpu: bool = False
    keep_vae_on_cpu: bool = False
|
|
|
|
@dataclass
class Txt2ImgConfig:
    """Text-to-image generation parameters for ``GGUFDiffusion.generate_text2mg``.

    Every field is forwarded positionally to the native ``txt2img`` call.
    ``prompt`` and ``negative_prompt`` are UTF-8 encoded before the call.
    ``input_id_images_path`` is passed through unchanged, so it must already
    be ``bytes``.
    """

    prompt: str = "a man wearing sun glasses, highly detailed"
    negative_prompt: str = ""
    clip_skip: int = -1
    cfg_scale: float = 2.0
    guidance: float = 3.5
    width: int = 512
    height: int = 512
    sample_method: SampleMethod = SampleMethod.EULER_A
    sample_steps: int = 1
    seed: int = -1
    # Number of images requested; also the number of results read back from
    # the native output buffer.
    batch_count: int = 2
    # Bound to a ``POINTER(SDImage)`` argument: ``None`` becomes a NULL
    # pointer, and an ``SDImage`` instance is passed by reference by ctypes.
    # A PIL image is NOT accepted here (the old ``Image`` annotation named the
    # PIL module, not a type).
    control_cond: Optional[SDImage] = None
    control_strength: float = 0.90
    style_strength: float = 0.5
    normalize_input: bool = False
    input_id_images_path: bytes = b""
|
|
|
|
class GGUFDiffusion:
    """GGUF diffusion model wrapper.

    Loads the stable-diffusion.cpp shared library through ``ctypes`` and
    exposes text-to-image generation for GGUF models
    (format: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md).
    Implemented against the C API declared in stable-diffusion.h.
    """

    def __init__(
        self,
        libpath: str,
        config: ModelConfig,
        logging_enabled: bool = False,
    ):
        """Load the shared library, validate model files and create the context.

        Args:
            libpath: Directory containing the platform-specific
                stable-diffusion shared library.
            config: Model paths and options forwarded to ``new_sd_ctx``.
            logging_enabled: When true, native log messages are printed.

        Raises:
            ValueError: If the platform is unsupported, the library cannot be
                loaded, a required model file is missing, or the native
                context could not be created.
        """
        sdcpp_shared_lib_path = self._get_sdcpp_shared_lib_path(libpath)
        try:
            self.libsdcpp = ctypes.CDLL(sdcpp_shared_lib_path)
        except OSError as e:
            print(f"Failed to load library {sdcpp_shared_lib_path}")
            raise ValueError(f"Error: {e}") from e

        self._validate_model_files(config)
        self.model_config = config

        # Install the log callback before creating the context so messages
        # emitted while the model files are being loaded are shown as well.
        if logging_enabled:
            self._set_logcallback()

        self.libsdcpp.new_sd_ctx.argtypes = [
            c_char_p,  # model_path
            c_char_p,  # clip_l_path
            c_char_p,  # t5xxl_path
            c_char_p,  # diffusion_model_path
            c_char_p,  # vae_path
            c_char_p,  # taesd_path
            c_char_p,  # control_net_path
            c_char_p,  # lora_model_dir
            c_char_p,  # embed_dir
            c_char_p,  # stacked_id_embed_dir
            c_bool,  # vae_decode_only
            c_bool,  # vae_tiling
            c_bool,  # free_params_immediately
            c_int,  # n_threads
            SdType,  # wtype
            RngType,  # rng_type
            Schedule,  # schedule
            c_bool,  # keep_clip_on_cpu
            c_bool,  # keep_control_net_cpu
            c_bool,  # keep_vae_on_cpu
        ]
        # The context is opaque to Python; it is only handed back to the library.
        self.libsdcpp.new_sd_ctx.restype = POINTER(c_void_p)

        cfg = self.model_config
        self.sd_ctx = self.libsdcpp.new_sd_ctx(
            self._str_to_bytes(cfg.model_path),
            self._str_to_bytes(cfg.clip_l_path),
            self._str_to_bytes(cfg.t5xxl_path),
            self._str_to_bytes(cfg.diffusion_model_path),
            self._str_to_bytes(cfg.vae_path),
            self._str_to_bytes(cfg.taesd_path),
            self._str_to_bytes(cfg.control_net_path),
            self._str_to_bytes(cfg.lora_model_dir),
            self._str_to_bytes(cfg.embed_dir),
            self._str_to_bytes(cfg.stacked_id_embed_dir),
            cfg.vae_decode_only,
            cfg.vae_tiling,
            cfg.free_params_immediately,
            cfg.n_threads,
            cfg.wtype,
            cfg.rng_type,
            cfg.schedule,
            cfg.keep_clip_on_cpu,
            cfg.keep_control_net_cpu,
            cfg.keep_vae_on_cpu,
        )

        # A NULL context means the native loader failed. Report the failure
        # here instead of letting every later generation call fail or return
        # nothing.
        if not self.sd_ctx:
            raise ValueError(
                "Failed to create stable-diffusion.cpp context, "
                "please check the model files"
            )

    @staticmethod
    def _validate_model_files(config: ModelConfig) -> None:
        """Raise ``ValueError`` if a required model file path is empty or missing."""
        required_files = (
            (config.clip_l_path, "CLIP"),
            (config.t5xxl_path, "T5XXL"),
            (config.diffusion_model_path, "Diffusion"),
            (config.vae_path, "VAE"),
        )
        for file_path, label in required_files:
            if not file_path or not path.exists(file_path):
                raise ValueError(
                    f"{label} model file not found,please check readme.md for GGUF model usage"
                )

    def _set_logcallback(self):
        """Route native log output to ``log_callback`` via ``sd_set_log_callback``."""
        print("Setting logging callback")

        SdLogCallbackType = ctypes.CFUNCTYPE(
            None,
            SDCPPLogLevel,
            ctypes.c_char_p,
            ctypes.c_void_p,
        )

        self.libsdcpp.sd_set_log_callback.argtypes = [
            SdLogCallbackType,
            ctypes.c_void_p,
        ]
        self.libsdcpp.sd_set_log_callback.restype = None

        # Keep a reference on self: if the CFUNCTYPE object were garbage
        # collected, the library would be left calling a dangling pointer.
        self.c_log_callback = SdLogCallbackType(self.log_callback)
        self.libsdcpp.sd_set_log_callback(self.c_log_callback, None)

    def _get_sdcpp_shared_lib_path(
        self,
        root_path: str,
    ) -> str:
        """Return the platform-specific shared-library path under ``root_path``.

        Raises:
            ValueError: If the current platform is not Windows, Linux or macOS.
        """
        system_name = platform.system()
        print(f"GGUF Diffusion on {system_name}")
        lib_names = {
            "Windows": "stable-diffusion.dll",
            "Linux": "libstable-diffusion.so",
            "Darwin": "libstable-diffusion.dylib",
        }
        lib_name = lib_names.get(system_name)
        if lib_name is None:
            # Previously an empty path was returned here and handed to CDLL,
            # which produced a confusing failure later on.
            raise ValueError(f"Unsupported platform: {system_name}")
        return path.join(root_path, lib_name)

    @staticmethod
    def log_callback(
        level,
        text,
        data,
    ):
        """Print one native log message without adding a newline.

        ``text`` arrives as ``bytes``, or ``None`` for a NULL pointer.
        Undecodable bytes are replaced rather than raising, because an
        exception inside a ctypes callback cannot propagate to Python callers.
        """
        if text is None:
            return
        print(text.decode("utf-8", errors="replace"), end="")

    def _str_to_bytes(self, in_str: str, encoding: str = "utf-8") -> bytes:
        """Encode ``in_str`` for a ``c_char_p`` argument; falsy input gives ``b""``."""
        if in_str:
            return in_str.encode(encoding)
        else:
            return b""

    def generate_text2mg(self, txt2img_cfg: Txt2ImgConfig) -> List[Any]:
        """Run native ``txt2img`` and convert the result to PIL images.

        Args:
            txt2img_cfg: Generation parameters, forwarded positionally.

        Returns:
            One PIL image per requested batch entry, or an empty list when the
            native call returns a NULL buffer.
        """
        self.libsdcpp.txt2img.restype = POINTER(SDImage)
        self.libsdcpp.txt2img.argtypes = [
            c_void_p,  # sd_ctx
            c_char_p,  # prompt
            c_char_p,  # negative_prompt
            c_int,  # clip_skip
            c_float,  # cfg_scale
            c_float,  # guidance
            c_int,  # width
            c_int,  # height
            SampleMethod,  # sample_method
            c_int,  # sample_steps
            c_int64,  # seed
            c_int,  # batch_count
            POINTER(SDImage),  # control_cond
            c_float,  # control_strength
            c_float,  # style_strength
            c_bool,  # normalize_input
            c_char_p,  # input_id_images_path
        ]

        image_buffer = self.libsdcpp.txt2img(
            self.sd_ctx,
            self._str_to_bytes(txt2img_cfg.prompt),
            self._str_to_bytes(txt2img_cfg.negative_prompt),
            txt2img_cfg.clip_skip,
            txt2img_cfg.cfg_scale,
            txt2img_cfg.guidance,
            txt2img_cfg.width,
            txt2img_cfg.height,
            txt2img_cfg.sample_method,
            txt2img_cfg.sample_steps,
            txt2img_cfg.seed,
            txt2img_cfg.batch_count,
            txt2img_cfg.control_cond,
            txt2img_cfg.control_strength,
            txt2img_cfg.style_strength,
            txt2img_cfg.normalize_input,
            txt2img_cfg.input_id_images_path,
        )

        return self._get_sd_images_from_buffer(
            image_buffer,
            txt2img_cfg.batch_count,
        )

    def _get_sd_images_from_buffer(
        self,
        image_buffer: Any,
        batch_count: int,
    ) -> List[Any]:
        """Convert the first ``batch_count`` native images into PIL images.

        Raises:
            ValueError: If an image has a channel count other than 1, 3 or 4.
        """
        channel_modes = {1: "L", 3: "RGB", 4: "RGBA"}
        images = []
        if not image_buffer:
            return images

        for i in range(batch_count):
            image = image_buffer[i]
            print(
                f"Generated image: {image.width}x{image.height} with {image.channel} channels"
            )

            width = image.width
            height = image.height
            channels = image.channel
            mode = channel_modes.get(channels)
            if mode is None:
                raise ValueError(f"Unsupported number of channels: {channels}")

            # NOTE(review): this is a view over memory owned by the native
            # library, and the buffer is never freed here. Copy it before
            # adding any native free call.
            pixel_data = np.ctypeslib.as_array(
                image.data, shape=(height, width, channels)
            )
            if channels == 1:
                # Drop only the channel axis. squeeze() would also collapse a
                # height or width of 1.
                pixel_data = pixel_data[:, :, 0]

            images.append(Image.fromarray(pixel_data, mode=mode))
        return images

    def terminate(self):
        """Free the native context and release the library handle.

        Safe to call more than once. Afterwards both ``sd_ctx`` and
        ``libsdcpp`` are ``None``.
        """
        if not self.libsdcpp:
            return
        if self.sd_ctx:
            self.libsdcpp.free_sd_ctx.argtypes = [c_void_p]
            self.libsdcpp.free_sd_ctx.restype = None
            self.libsdcpp.free_sd_ctx(self.sd_ctx)
        self.sd_ctx = None
        self.libsdcpp = None
|
|