| import modules.scripts as scripts
|
| import gradio as gr
|
|
|
| from modules import images
|
| from modules.processing import process_images
|
| from modules.shared import opts
|
| import numpy as np
|
|
|
|
|
class Script(scripts.Script):
    """WebUI script "txt2palette": turn each generated image into a color-palette strip.

    After the normal txt2img pass, ``palette_size`` colors are extracted from every
    output sample (median cut or KMeans), sorted by the selected key, rendered as a
    row of solid swatches that replaces the sample, and saved.
    """

    def title(self):
        return "txt2palette"

    def show(self, is_img2img):
        # Only offered on the txt2img tab.
        return not is_img2img

    def ui(self, is_img2img):
        # The slider default must lie inside [minimum, maximum]; the previous
        # default of 0 made KMeans fail with n_clusters=0.
        palette_size = gr.Slider(minimum=1, maximum=64, step=1, value=5,
                                 label="Palette size")
        method = gr.Radio(choices=['Median cut', 'KMeans'], value='Median cut', label='Palette extraction method')
        sort_by = gr.Radio(choices=["luminance", "hue", "saturation", "value", "lightness"], value="luminance", label="Sort colors by")
        overwrite = gr.Checkbox(False, label="Overwrite existing files")
        return [palette_size, method, sort_by, overwrite]

    def run(self, p, palette_size, method, sort_by, overwrite):
        """Run generation, then replace every sample with its extracted palette.

        :param p: processing object handed in by the WebUI; passed to process_images
        :param palette_size: number of colors to extract (slider value)
        :param method: 'Median cut' or 'KMeans' (falls back to median cut if
            sklearn is not importable)
        :param sort_by: key of ``sort_methods`` used to order the swatches
        :param overwrite: if True, the original samples are not saved
            (``p.do_not_save_samples``) and palettes are saved without a filename
            suffix; otherwise palettes are saved with a ``_palette_<N>x`` suffix
        :return: the result of process_images with sample images replaced by palettes
        """
        import colorsys

        from PIL import Image

        # Slider values may arrive as floats (or as a stale 0 from a saved UI
        # config); KMeans needs a positive int and the suffix should read "5x".
        palette_size = max(1, int(palette_size))

        try:
            from sklearn.cluster import KMeans
        except ImportError:
            if method == 'KMeans':
                print('"sklearn" library is not installed, switching the extraction method to Median cut.')
                method = "Median cut"

        class Color:
            """A single RGB color (0-255 per channel) with an associated weight."""

            luminance_weights = np.array([0.2126, 0.7152, 0.0722])

            def __init__(self, RGB, frequency):
                # Convert to plain ints (truncating) so PIL/colorsys get native
                # numbers whether RGB is a numpy row (KMeans) or a float mean (median cut).
                self.rgb = tuple(int(c) for c in RGB)
                self.freq = frequency

            def display(self, w=50, h=50):
                """
                Displays the represented color in a w x h window.
                :param w: width in pixels
                :param h: height in pixels
                """
                img = Image.new("RGB", size=(w, h), color=self.rgb)
                img.show()

            def __lt__(self, other):
                return self.freq < other.freq

            def get_colors(self, colorspace="rgb"):
                """
                Get the color in terms of a colorspace (string).
                :param colorspace: rgb/hsv/hls
                :return: corresponding color values
                """
                colors = {"rgb": self.rgb, "hsv": self.hsv, "hls": self.hls}
                return colors[colorspace]

            @property
            def _unit_rgb(self):
                # colorsys expects floats in [0, 1]; 0-255 input yields a wrong HLS
                # saturation and can divide by zero (e.g. rgb (2, 1, 0)).
                return tuple(c / 255.0 for c in self.rgb)

            @property
            def hsv(self):
                return colorsys.rgb_to_hsv(*self._unit_rgb)

            @property
            def hls(self):
                return colorsys.rgb_to_hls(*self._unit_rgb)

            @property
            def luminance(self):
                return np.dot(self.luminance_weights, self.rgb)

        class ColorBox:
            """
            Represents a box in the RGB color space, with associated attributes, used in the Median Cut algorithm.
            """

            def __init__(self, colors):
                """
                Initialize with a numpy array of RGB colors.
                :param colors: np.ndarray (width * height, 3)
                """
                self.colors = colors
                self._get_min_max()

            def _get_min_max(self):
                self.min_channel = np.min(self.colors, axis=0)
                self.max_channel = np.max(self.colors, axis=0)

            def __lt__(self, other):
                """Compare boxes by volume."""
                return self.size < other.size

            @property
            def size(self):
                return self.volume

            def _get_dominant_channel(self):
                # Channel with the widest value range inside this box.
                return np.argmax(self.max_channel - self.min_channel)

            @property
            def average(self):
                """
                Returns the average color contained in ColorBox
                :return: [R, G, B]
                """
                return np.mean(self.colors, axis=0)

            @property
            def volume(self):
                return np.prod(self.max_channel - self.min_channel)

            def split(self):
                """
                Splits the ColorBox into two ColorBoxes at the median of the dominant color channel.
                Requires at least two pixels, otherwise one half would be empty.
                :return: [ColorBox1, ColorBox2]
                """
                dominant_channel = self._get_dominant_channel()
                ordered = self.colors[self.colors[:, dominant_channel].argsort()]
                median_index = len(ordered) // 2
                return [
                    ColorBox(ordered[:median_index]),
                    ColorBox(ordered[median_index:]),
                ]

        class Palette:
            def __init__(self, colors):
                """
                Initializes a color palette with a list of Color objects.
                :param colors: a list of Color-objects
                """
                self.colors = colors
                self.frequencies = [c.freq for c in colors]
                self.number_of_colors = len(colors)

            def get_image(self, w=50, h=50):
                """Render the palette as a (w * n) x h strip of solid w x h swatches."""
                arr = np.zeros((h, w * self.number_of_colors, 3), dtype=np.uint8)
                for i, c in enumerate(self.colors):
                    # Step along the x axis by the swatch *width* (previously used h).
                    arr[:, i * w:(i + 1) * w, :] = c.rgb
                return Image.fromarray(arr, "RGB")

        def k_means_extraction(arr, height, width, palette_size):
            """
            Extracts a color palette using KMeans.
            :param arr: pixel array (height, width, 3)
            :param height: height
            :param width: width
            :param palette_size: number of colors
            :return: a list of Color objects weighted by the share of pixels in each cluster
            """
            arr = np.reshape(arr, (width * height, -1))
            model = KMeans(n_clusters=palette_size)
            labels = model.fit_predict(arr)
            palette = np.array(model.cluster_centers_, dtype=int)
            # minlength keeps one count per cluster even when trailing clusters
            # end up empty, so zip() below does not silently drop colors.
            color_count = np.bincount(labels, minlength=palette_size)
            color_frequency = color_count / float(np.sum(color_count))
            return [Color(color, freq) for color, freq in zip(palette, color_frequency)]

        def median_cut_extraction(arr, height, width, palette_size):
            """
            Extracts a color palette using the median cut algorithm.
            :param arr: pixel array (height, width, 3)
            :param height: height
            :param width: width
            :param palette_size: number of colors (fewer are returned if the
                pixels cannot be split further)
            :return: a list of Color objects weighted by the share of pixels in each box
            """
            arr = arr.reshape((width * height, -1))
            total_pixels = len(arr)
            boxes = [ColorBox(arr)]
            while len(boxes) < palette_size:
                # Only boxes holding at least two pixels can be halved; splitting a
                # one-pixel box would create an empty box and np.min would raise.
                splittable = [i for i, box in enumerate(boxes) if len(box.colors) > 1]
                if not splittable:
                    break
                # First box with the largest volume (same tie-breaking as np.argmax).
                idx = max(splittable, key=lambda i: boxes[i].size)
                boxes = boxes[:idx] + boxes[idx].split() + boxes[idx + 1:]
            # Weight by pixel share rather than by box volume, which is 0 for
            # flat images and would divide by zero.
            return [Color(box.average, len(box.colors) / total_pixels) for box in boxes]

        sort_methods = {
            "luminance": lambda c: c.luminance,
            "hue": lambda c: c.hsv[0],
            "saturation": lambda c: c.hsv[1],
            "value": lambda c: c.hsv[2],
            # colorsys HLS is (hue, lightness, saturation): lightness is index 1.
            "lightness": lambda c: c.hls[1],
        }

        def extract_colors(image, palette_size=5, resize=True, mode="Median cut", sort_mode=None):
            """
            Extracts a set of 'palette_size' colors from the given image.
            :param image: PIL.Image object or path to an image file
            :param palette_size: number of colors to extract
            :param resize: whether to resize the image to 256x256 before processing, yielding faster results with lower quality
            :param mode: the color quantization algorithm to use: "KMeans" or "Median cut"
            :param sort_mode: key of sort_methods used to order the colors
            :return: a Palette of the extracted colors
            :raises NotImplementedError: for an unknown mode or sort_mode
            """
            if isinstance(image, Image.Image):
                img = image
            else:
                img = Image.open(image)
            img = img.convert("RGB")
            if resize:
                img = img.resize((256, 256))
            width, height = img.size
            arr = np.asarray(img)

            if mode == "KMeans":
                colors = k_means_extraction(arr, height, width, palette_size)
            elif mode == "Median cut":
                colors = median_cut_extraction(arr, height, width, palette_size)
            else:
                raise NotImplementedError("Extraction mode not implemented!")

            if sort_mode not in sort_methods:
                raise NotImplementedError("Sorting mode not implemented!")
            colors.sort(key=sort_methods[sort_mode], reverse=False)
            return Palette(colors)

        if overwrite:
            # Originals are not saved; palettes are saved without a suffix.
            # (Previously basename was left undefined here -> NameError below.)
            p.do_not_save_samples = True
            basename = ""
        else:
            basename = f"_palette_{palette_size}x"

        proc = process_images(p)

        # proc.images may start with a grid image. Prefer the index reported by
        # the WebUI when available; otherwise keep the old heuristic.
        first_image = getattr(proc, "index_of_first_image", None)
        if first_image is None:
            first_image = 1 if len(proc.images) > 1 else 0

        for i, idx in enumerate(range(first_image, len(proc.images))):
            pal = extract_colors(proc.images[idx], palette_size=palette_size, sort_mode=sort_by, mode=method)
            proc.images[idx] = pal.get_image()

            images.save_image(proc.images[idx], p.outpath_samples, basename,
                              proc.seed + i, proc.prompt, opts.samples_format, info=proc.info, p=p)

        return proc