dikdimon committed on
Commit
1a5cc4f
·
verified ·
1 Parent(s): db57927

Upload 7 files

Browse files
sec/processing.py ADDED
@@ -0,0 +1,1860 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import json
3
+ import logging
4
+ import math
5
+ import os
6
+ import sys
7
+ import hashlib
8
+ from dataclasses import dataclass, field
9
+
10
+ import torch
11
+ import numpy as np
12
+ from PIL import Image, ImageOps
13
+ import random
14
+ import cv2
15
+ from skimage import exposure
16
+ from typing import Any
17
+
18
+ import modules.sd_hijack
19
+ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
20
+ from modules.rng import slerp # noqa: F401
21
+ from modules.sd_hijack import model_hijack
22
+ from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
23
+ from modules.shared import opts, cmd_opts, state
24
+ import modules.shared as shared
25
+ import modules.paths as paths
26
+ import modules.face_restoration
27
+ import modules.images as images
28
+ import modules.styles
29
+ import modules.sd_models as sd_models
30
+ import modules.sd_vae as sd_vae
31
+ from ldm.data.util import AddMiDaS
32
+ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
33
+
34
+ from einops import repeat, rearrange
35
+ from blendmodes.blend import blendLayers, BlendType
36
+
37
+
38
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4  # presumably the latent channel count of the SD model — confirm against usage elsewhere in the file
opt_f = 8  # presumably the VAE pixel-to-latent downscale factor — confirm against usage elsewhere in the file
41
+
42
+
43
def setup_color_correction(image):
    """Capture a LAB-space histogram target from *image* for later use by apply_color_correction()."""
    logging.info("Calibrating color correction.")
    # Copy first so the caller's PIL image is never mutated, then convert to LAB.
    pixels = np.asarray(image.copy())
    return cv2.cvtColor(pixels, cv2.COLOR_RGB2LAB)
47
+
48
+
49
def apply_color_correction(correction, original_image):
    """Match *original_image*'s LAB histogram to the `correction` target, keeping the original luminosity.

    Returns a new RGB PIL image; *original_image* is not modified.
    """
    logging.info("Applying color correction.")

    lab = cv2.cvtColor(np.asarray(original_image), cv2.COLOR_RGB2LAB)
    matched = exposure.match_histograms(lab, correction, channel_axis=2)
    corrected = Image.fromarray(cv2.cvtColor(matched, cv2.COLOR_LAB2RGB).astype("uint8"))

    # Take only chroma from the histogram-matched result; brightness comes from the original.
    blended = blendLayers(corrected, original_image, BlendType.LUMINOSITY)

    return blended.convert('RGB')
63
+
64
+
65
def uncrop(image, dest_size, paste_loc):
    """Paste *image* back onto a transparent canvas of *dest_size* at the crop rectangle *paste_loc* (x, y, w, h)."""
    x, y, w, h = paste_loc
    canvas = Image.new('RGBA', dest_size)
    resized = images.resize_image(1, image, w, h)
    canvas.paste(resized, (x, y))
    return canvas
73
+
74
+
75
def apply_overlay(image, paste_loc, overlay):
    """Composite *overlay* on top of *image*.

    If *paste_loc* is given, *image* is first un-cropped onto a canvas the size
    of the overlay.  Returns (final image, copy of the image from before the
    overlay was composited).  With overlay=None the image passes through
    untouched alongside a plain copy.
    """
    if overlay is None:
        return image, image.copy()

    if paste_loc is not None:
        image = uncrop(image, (overlay.width, overlay.height), paste_loc)

    before_overlay = image.copy()

    composited = image.convert('RGBA')
    composited.alpha_composite(overlay)

    return composited.convert('RGB'), before_overlay
89
+
90
def create_binary_mask(image, round=True):
    """Collapse *image* to a single-channel "L" mask.

    An RGBA image whose alpha channel is not uniformly 255 is masked by its
    alpha; any other image is simply converted to grayscale.  With round=True
    the alpha mask is thresholded at 128 into pure black/white.
    """
    has_meaningful_alpha = image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255)
    if not has_meaningful_alpha:
        return image.convert('L')

    alpha = image.split()[-1].convert("L")
    if round:
        alpha = alpha.point(lambda x: 255 if x > 128 else 0)
    return alpha
99
+
100
def txt2img_image_conditioning(sd_model, x, width, height):
    """Build the image-conditioning tensor for a txt2img pass.

    Inpainting models get an all-masked flat-gray latent with a full mask
    channel; UnCLIP models get a zero adm embedding; everything else gets a
    tiny dummy zero tensor (only its batch size matters to consumers).
    """
    conditioning_key = sd_model.model.conditioning_key

    if conditioning_key in {'hybrid', 'concat'}:  # Inpainting models
        # Entire image is considered masked, so the "masked image" is flat 0.5 gray.
        source = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
        source = images_tensor_to_samples(source, approximation_indexes.get(opts.sd_vae_encode_method))

        # Prepend an all-ones mask channel.
        source = torch.nn.functional.pad(source, (0, 0, 0, 0, 1, 0), value=1.0)
        return source.to(x.dtype)

    if conditioning_key == "crossattn-adm":  # UnCLIP models
        return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)

    if sd_model.is_sdxl_inpaint:
        # SDXL inpainting: same flat-gray masked image + full mask as the hybrid/concat case.
        source = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
        source = images_tensor_to_samples(source, approximation_indexes.get(opts.sd_vae_encode_method))
        source = torch.nn.functional.pad(source, (0, 0, 0, 0, 1, 0), value=1.0)
        return source.to(x.dtype)

    # Dummy zero conditioning when not using inpainting or unclip models.
    # Kept tiny on purpose: nothing reads it besides its batch size.
    return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
134
+
135
+
136
@dataclass(repr=False)
class StableDiffusionProcessing:
    """Holds every parameter and piece of working state for one generation job.

    Instances are configured through the dataclass fields below and then driven
    by the processing pipeline; subclasses implement sample().  Fields declared
    with init=False are internal working state filled in during processing, not
    user input.
    """

    # --- user-facing generation parameters ---
    sd_model: object = None  # shadowed by the sd_model property below; setter is a no-op
    outpath_samples: str = None
    outpath_grids: str = None
    prompt: str = ""
    prompt_for_display: str = None
    negative_prompt: str = ""
    styles: list[str] = None
    seed: int = -1  # -1 means "pick a random seed" (resolved by fix_seed/get_fixed_seed)
    subseed: int = -1
    subseed_strength: float = 0
    seed_resize_from_h: int = -1
    seed_resize_from_w: int = -1
    seed_enable_extras: bool = True  # when False, __post_init__ zeroes out all subseed/resize extras
    sampler_name: str = None
    scheduler: str = None
    batch_size: int = 1
    n_iter: int = 1
    steps: int = 50
    cfg_scale: float = 7.0
    width: int = 512
    height: int = 512
    restore_faces: bool = None
    tiling: bool = None
    do_not_save_samples: bool = False
    do_not_save_grid: bool = False
    extra_generation_params: dict[str, Any] = None  # extra key/value pairs merged into the infotext
    overlay_images: list = None
    eta: float = None
    do_not_reload_embeddings: bool = False
    denoising_strength: float = None
    ddim_discretize: str = None
    # sampler "sigma" tweaks; None means "take the value from opts" (see fill_fields_from_opts)
    s_min_uncond: float = None
    s_churn: float = None
    s_tmax: float = None
    s_tmin: float = None
    s_noise: float = None
    override_settings: dict[str, Any] = None
    override_settings_restore_afterwards: bool = True
    sampler_index: int = None  # deprecated; __post_init__ warns and ignores it
    refiner_checkpoint: str = None
    refiner_switch_at: float = None
    token_merging_ratio = 0
    token_merging_ratio_hr = 0
    disable_extra_networks: bool = False
    firstpass_image: Image = None

    # scripts/script_args are exposed as properties so that setting the second of
    # the pair triggers setup_scripts() exactly once (scripts_setup_complete).
    scripts_value: scripts.ScriptRunner = field(default=None, init=False)
    script_args_value: list = field(default=None, init=False)
    scripts_setup_complete: bool = field(default=False, init=False)

    # class-level cond caches shared by all instances: [cached_params, result]
    cached_uc = [None, None]
    cached_c = [None, None]

    # --- internal working state ---
    comments: dict = None
    sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
    is_using_inpainting_conditioning: bool = field(default=False, init=False)
    paste_to: tuple | None = field(default=None, init=False)

    is_hr_pass: bool = field(default=False, init=False)

    c: tuple = field(default=None, init=False)   # positive conditioning
    uc: tuple = field(default=None, init=False)  # negative (unconditional) conditioning

    rng: rng.ImageRNG | None = field(default=None, init=False)
    step_multiplier: int = field(default=1, init=False)
    color_corrections: list = field(default=None, init=False)

    all_prompts: list = field(default=None, init=False)
    all_negative_prompts: list = field(default=None, init=False)
    all_seeds: list = field(default=None, init=False)
    all_subseeds: list = field(default=None, init=False)
    iteration: int = field(default=0, init=False)
    main_prompt: str = field(default=None, init=False)
    main_negative_prompt: str = field(default=None, init=False)

    # per-batch slices of the all_* lists
    prompts: list = field(default=None, init=False)
    negative_prompts: list = field(default=None, init=False)
    seeds: list = field(default=None, init=False)
    subseeds: list = field(default=None, init=False)
    extra_network_data: dict = field(default=None, init=False)

    user: str = field(default=None, init=False)

    sd_model_name: str = field(default=None, init=False)
    sd_model_hash: str = field(default=None, init=False)
    sd_vae_name: str = field(default=None, init=False)
    sd_vae_hash: str = field(default=None, init=False)

    is_api: bool = field(default=False, init=False)

    def __post_init__(self):
        if self.sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

        self.comments = {}

        if self.styles is None:
            self.styles = []

        self.sampler_noise_scheduler_override = None

        # NOTE(review): script_args defaults to {} here although script_args_value
        # is annotated as a list — confirm which container callers expect.
        self.extra_generation_params = self.extra_generation_params or {}
        self.override_settings = self.override_settings or {}
        self.script_args = self.script_args or {}

        self.refiner_checkpoint_info = None

        if not self.seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        # alias the shared class-level caches onto the instance
        self.cached_uc = StableDiffusionProcessing.cached_uc
        self.cached_c = StableDiffusionProcessing.cached_c

    def fill_fields_from_opts(self):
        """Replace any still-None sigma parameter with its value from opts."""
        self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
        self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
        self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
        # an s_tmax of 0 (falsy) is treated as "no limit"
        self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
        self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise

    @property
    def sd_model(self):
        # always the currently loaded shared model; assignment is deliberately ignored
        return shared.sd_model

    @sd_model.setter
    def sd_model(self, value):
        pass

    @property
    def scripts(self):
        return self.scripts_value

    @scripts.setter
    def scripts(self, value):
        self.scripts_value = value

        # run one-time script setup as soon as both scripts and their args are present
        if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
            self.setup_scripts()

    @property
    def script_args(self):
        return self.script_args_value

    @script_args.setter
    def script_args(self, value):
        self.script_args_value = value

        if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
            self.setup_scripts()

    def setup_scripts(self):
        self.scripts_setup_complete = True

        # "setup_scrips" [sic] — presumably matches the ScriptRunner method name; confirm
        self.scripts.setup_scrips(self, is_ui=not self.is_api)

    def comment(self, text):
        # comments is used as an ordered set: key = comment text
        self.comments[text] = 1

    def txt2img_image_conditioning(self, x, width=None, height=None):
        """Delegate to the module-level helper, recording whether inpainting conditioning is in use."""
        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}

        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)

    def depth2img_image_conditioning(self, source_image):
        # Use the AddMiDaS helper to Format our source image to suit the MiDaS model
        transformer = AddMiDaS(model_type="dpt_hybrid")
        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

        # encode the source image ([-1,1] -> [0,1]) and scale the depth map to match its latent size
        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
        conditioning = torch.nn.functional.interpolate(
            self.sd_model.depth_model(midas_in),
            size=conditioning_image.shape[2:],
            mode="bicubic",
            align_corners=False,
        )

        # normalize depth to [-1, 1]
        (depth_min, depth_max) = torch.aminmax(conditioning)
        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
        return conditioning

    def edit_image_conditioning(self, source_image):
        # instruct-pix2pix style conditioning: the encoded source image itself
        conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()

        return conditioning_image

    def unclip_image_conditioning(self, source_image):
        c_adm = self.sd_model.embedder(source_image)
        if self.sd_model.noise_augmentor is not None:
            noise_level = 0 # TODO: Allow other noise levels?
            c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
            c_adm = torch.cat((c_adm, noise_level_emb), 1)
        return c_adm

    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
        """Build the mask+masked-image conditioning tensor fed to c_concat for inpainting models."""
        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

            if round_image_mask:
                # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)

        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )

        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
        """Dispatch to the right conditioning builder for the current model type."""
        source_image = devices.cond_cast_float(source_image)

        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
        # identify itself with a field common to all models. The conditioning_key is also hybrid.
        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
            return self.depth2img_image_conditioning(source_image)

        if self.sd_model.cond_stage_key == "edit":
            return self.edit_image_conditioning(source_image)

        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)

        if self.sampler.conditioning_key == "crossattn-adm":
            return self.unclip_image_conditioning(source_image)

        if self.sampler.model_wrap.inner_model.is_sdxl_inpaint:
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

        # Dummy zero conditioning if we're not using inpainting or depth model.
        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Hook for subclasses: prepare state before sampling begins."""
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Subclasses must implement the actual sampling step."""
        raise NotImplementedError()

    def close(self):
        """Drop references to the sampler and conds; optionally clear the shared cond caches."""
        self.sampler = None
        self.c = None
        self.uc = None
        if not opts.persistent_cond_cache:
            StableDiffusionProcessing.cached_c = [None, None]
            StableDiffusionProcessing.cached_uc = [None, None]

    def get_token_merging_ratio(self, for_hr=False):
        # instance value wins over the global option; hr values fall back to the non-hr ones
        if for_hr:
            return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio

        return self.token_merging_ratio or opts.token_merging_ratio

    def setup_prompts(self):
        """Expand prompt/negative_prompt into per-image lists of equal length and apply styles."""
        if isinstance(self.prompt, list):
            self.all_prompts = self.prompt
        elif isinstance(self.negative_prompt, list):
            # mirror the length of the negative prompt list
            self.all_prompts = [self.prompt] * len(self.negative_prompt)
        else:
            self.all_prompts = self.batch_size * self.n_iter * [self.prompt]

        if isinstance(self.negative_prompt, list):
            self.all_negative_prompts = self.negative_prompt
        else:
            self.all_negative_prompts = [self.negative_prompt] * len(self.all_prompts)

        if len(self.all_prompts) != len(self.all_negative_prompts):
            raise RuntimeError(f"Received a different number of prompts ({len(self.all_prompts)}) and negative prompts ({len(self.all_negative_prompts)})")

        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]

        self.main_prompt = self.all_prompts[0]
        self.main_negative_prompt = self.all_negative_prompts[0]

    def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False):
        """Returns parameters that invalidate the cond cache if changed"""

        return (
            required_prompts,
            steps,
            hires_steps,
            use_old_scheduling,
            opts.CLIP_stop_at_last_layers,
            shared.sd_model.sd_checkpoint_info,
            extra_network_data,
            opts.sdxl_crop_left,
            opts.sdxl_crop_top,
            self.width,
            self.height,
            opts.fp8_storage,
            opts.cache_fp16_weight,
            opts.emphasis,
        )

    def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
        """
        Returns the result of calling function(shared.sd_model, required_prompts, steps)
        using a cache to store the result if the same arguments have been used before.

        cache is an array containing two elements. The first element is a tuple
        representing the previously used arguments, or None if no arguments
        have been used before. The second element is where the previously
        computed result is stored.

        caches is a list with items described above.
        """

        if shared.opts.use_old_scheduling:
            # surface a note in the infotext when old/new prompt scheduling would differ
            old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, False)
            new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, True)
            if old_schedules != new_schedules:
                self.extra_generation_params["Old prompt editing timelines"] = True

        cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)

        for cache in caches:
            if cache[0] is not None and cached_params == cache[0]:
                return cache[1]

        # cache miss: compute into the first cache slot
        cache = caches[0]

        with devices.autocast():
            cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)

        cache[0] = cached_params
        return cache[1]

    def setup_conds(self):
        """Compute (or fetch cached) positive and negative conditionings for the current batch."""
        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)

        sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
        # some samplers run more than one model call per user-visible step
        total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
        self.step_multiplier = total_steps // self.steps
        self.firstpass_steps = total_steps

        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)

    def get_conds(self):
        return self.c, self.uc

    def parse_extra_network_prompts(self):
        # strip <network:...> syntax out of the prompts and remember what was requested
        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)

    def save_samples(self) -> bool:
        """Returns whether generated images need to be written to disk"""
        return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)
514
+
515
+
516
class Processed:
    """Record of one finished generation: output images plus every parameter
    needed to reproduce and describe them (seeds, prompts, model hashes, ...).

    Most fields are copied off the StableDiffusionProcessing instance at
    construction time so the Processed object stays valid after `p` is closed.
    """

    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        # p.comments is a dict used as an ordered set of comment strings
        self.comments = "".join(f"{comment}\n" for comment in p.comments)
        self.width = p.width
        self.height = p.height
        self.sampler_name = p.sampler_name
        self.cfg_scale = p.cfg_scale
        self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_name = p.sd_model_name
        self.sd_model_hash = p.sd_model_hash
        self.sd_vae_name = p.sd_vae_name
        self.sd_vae_hash = p.sd_vae_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers
        self.token_merging_ratio = p.token_merging_ratio
        self.token_merging_ratio_hr = p.token_merging_ratio_hr

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.s_min_uncond = p.s_min_uncond
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
        # normalize list-valued prompt/seed inputs down to their first element
        self.prompt = self.prompt if not isinstance(self.prompt, list) else self.prompt[0]
        self.negative_prompt = self.negative_prompt if not isinstance(self.negative_prompt, list) else self.negative_prompt[0]
        self.seed = int(self.seed if not isinstance(self.seed, list) else self.seed[0]) if self.seed is not None else -1
        self.subseed = int(self.subseed if not isinstance(self.subseed, list) else self.subseed[0]) if self.subseed is not None else -1
        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning

        # prefer explicit arguments, then values recorded on p, then the single normalized value
        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info] * len(images_list)
        self.version = program_version()

    def js(self):
        """Serialize the generation parameters to a JSON string (for the UI);
        values json can't encode are emitted as null."""
        obj = {
            "prompt": self.all_prompts[0],
            "all_prompts": self.all_prompts,
            "negative_prompt": self.all_negative_prompts[0],
            "all_negative_prompts": self.all_negative_prompts,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_name": self.sampler_name,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_name": self.sd_model_name,
            "sd_model_hash": self.sd_model_hash,
            "sd_vae_name": self.sd_vae_name,
            "sd_vae_hash": self.sd_vae_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
            "styles": self.styles,
            "job_timestamp": self.job_timestamp,
            "clip_skip": self.clip_skip,
            "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
            "version": self.version,
        }

        return json.dumps(obj, default=lambda o: None)

    def infotext(self, p: StableDiffusionProcessing, index):
        """Build the infotext for the image at flat position *index* in the batch grid."""
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)

    def get_token_merging_ratio(self, for_hr=False):
        return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
614
+
615
+
616
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    """Draw the first batch of seeded noise for *shape* via rng.ImageRNG.

    `p` is accepted for interface compatibility and is unused.
    """
    generator = rng.ImageRNG(
        shape,
        seeds,
        subseeds=subseeds,
        subseed_strength=subseed_strength,
        seed_resize_from_h=seed_resize_from_h,
        seed_resize_from_w=seed_resize_from_w,
    )
    return generator.next()
619
+
620
+
621
class DecodedSamples(list):
    """A list of samples that have already been VAE-decoded.

    The class-level flag lets downstream code skip a redundant decode pass.
    """
    already_decoded = True
623
+
624
+
625
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
    """Decode a batch of latents into image tensors one sample at a time.

    With check_for_nans=True, a NaN produced by the VAE may trigger a one-time
    automatic dtype upgrade of the VAE (to bfloat16 or float32, depending on
    opts) followed by a retry of the failing sample; if the VAE is already at
    the target dtype, or both auto-fix options are off, the NansException is
    re-raised.  Returns a DecodedSamples list (optionally moved to
    target_device).
    """
    samples = DecodedSamples()

    if check_for_nans:
        devices.test_for_nans(batch, "unet")

    for i in range(batch.shape[0]):
        sample = decode_first_stage(model, batch[i:i + 1])[0]

        if check_for_nans:

            try:
                devices.test_for_nans(sample, "vae")
            except devices.NansException as e:
                # pick the dtype to retry with, based on the user's auto-fix settings
                if shared.opts.auto_vae_precision_bfloat16:
                    autofix_dtype = torch.bfloat16
                    autofix_dtype_text = "bfloat16"
                    autofix_dtype_setting = "Automatically convert VAE to bfloat16"
                    autofix_dtype_comment = ""
                elif shared.opts.auto_vae_precision:
                    autofix_dtype = torch.float32
                    autofix_dtype_text = "32-bit float"
                    autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
                    autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
                else:
                    raise e

                # already at the auto-fix dtype: the upgrade can't help, give up
                if devices.dtype_vae == autofix_dtype:
                    raise e

                errors.print_error_explanation(
                    "A tensor with all NaNs was produced in VAE.\n"
                    f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
                    f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
                )

                # globally switch the VAE dtype and retry this sample;
                # subsequent samples in the loop use the upgraded VAE too
                devices.dtype_vae = autofix_dtype
                model.first_stage_model.to(devices.dtype_vae)
                batch = batch.to(devices.dtype_vae)

                sample = decode_first_stage(model, batch[i:i + 1])[0]

        if target_device is not None:
            sample = sample.to(target_device)

        samples.append(sample)

    return samples
673
+
674
+
675
def get_fixed_seed(seed):
    """Normalize a user-supplied seed into a concrete integer.

    Empty string, None, an unparseable string, or -1 all produce a fresh
    random seed in [0, 4294967294); a numeric string is parsed; anything
    else is returned unchanged.
    """
    if seed is None or seed == '':
        normalized = -1
    elif isinstance(seed, str):
        try:
            normalized = int(seed)
        except Exception:
            normalized = -1
    else:
        normalized = seed

    if normalized == -1:
        return int(random.randrange(4294967294))

    return normalized
688
+
689
+
690
def fix_seed(p):
    """Replace p.seed and p.subseed in place with concrete integer seeds."""
    p.seed, p.subseed = get_fixed_seed(p.seed), get_fixed_seed(p.subseed)
693
+
694
+
695
def program_version():
    """Return the webui's git tag, or None when no tag is available."""
    import launch

    tag = launch.git_tag()

    return None if tag == "<none>" else tag
703
+
704
+
705
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
    """
    this function is used to generate the infotext that is stored in the generated images; it contains the parameters that are required to generate the image
    Args:
        p: StableDiffusionProcessing
        all_prompts: list[str]
        all_seeds: list[int]
        all_subseeds: list[int]
        comments: list[str]
        iteration: int
        position_in_batch: int
        use_main_prompt: bool
        index: int
        all_negative_prompts: list[str]

    Returns: str

    Extra generation params
    p.extra_generation_params dictionary allows for additional parameters to be added to the infotext
    this can be used by the base webui or extensions.
    To add a new entry, add a new key value pair, the dictionary key will be used as the key of the parameter in the infotext
    the value generation_params can be defined as:
        - str | None
        - List[str|None]
        - callable func(**kwargs) -> str | None

    When defined as a string, it will be used as is, without extra processing; this is the most common use case.

    Defining as a list allows for a parameter that changes across images in the job, for example, the 'Seed' parameter.
    The list should have the same length as the total number of images in the entire job.

    Defining as a callable function allows a parameter that cannot be generated earlier or when extra logic is required.
    For example 'Hires prompt': due to reasons the hr_prompt might be changed by processes in the pipeline or extensions
    and may vary across different images; defining it as a static string or list would not work.

    The function takes locals() as **kwargs, as such will have access to variables like 'p' and 'index'.
    the base signature of the function should be:
        func(**kwargs) -> str | None
    optionally it can have additional arguments that will be used in the function:
        func(p, index, **kwargs) -> str | None
    note: for better future compatibility even though this function will have access to all variables in the locals(),
        it is recommended to only use the arguments present in the function signature of create_infotext.
    For actual implementation examples, see StableDiffusionProcessingTxt2Img.init > get_hr_prompt.
    """

    # resolve which image of the whole job this infotext is for
    if use_main_prompt:
        index = 0
    elif index is None:
        index = position_in_batch + iteration * p.batch_size

    if all_negative_prompts is None:
        all_negative_prompts = p.all_negative_prompts

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
    enable_hr = getattr(p, 'enable_hr', False)
    token_merging_ratio = p.get_token_merging_ratio()
    token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)

    prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
    negative_prompt = p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]

    # ENSD is only recorded when the delta is non-zero AND the sampler
    # actually consumes eta noise seed delta
    uses_ensd = opts.eta_noise_seed_delta != 0
    if uses_ensd:
        uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)

    # None values are dropped from the final text; see the filtering join below
    generation_params = {
        "Steps": p.steps,
        "Sampler": p.sampler_name,
        "Schedule type": p.scheduler,
        "CFG scale": p.cfg_scale,
        "Image CFG scale": getattr(p, 'image_cfg_scale', None),
        "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
        "Face restoration": opts.face_restoration_model if p.restore_faces else None,
        "Size": f"{p.width}x{p.height}",
        "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
        "Model": p.sd_model_name if opts.add_model_name_to_info else None,
        "FP8 weight": opts.fp8_storage if devices.fp8 else None,
        "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
        "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
        "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": p.extra_generation_params.get("Denoising strength"),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
        "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
        "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
        "Init image hash": getattr(p, 'init_img_hash', None),
        "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
        "Tiling": "True" if p.tiling else None,
        "Progressive Growing": "True" if p.enable_progressive_growing else None,
        "Min Scale": p.progressive_growing_min_scale if p.enable_progressive_growing else None,
        "Max Scale": p.progressive_growing_max_scale if p.enable_progressive_growing else None,
        "Progressive Growing Steps": p.progressive_growing_steps if p.enable_progressive_growing else None,
        "Refinement": "True" if p.progressive_growing_refinement and p.enable_progressive_growing else None,
        **p.extra_generation_params,
        "Version": program_version() if opts.add_version_to_infotext else None,
        "User": p.user if opts.add_user_name_to_info else None,
    }

    # resolve list-valued (per-image) and callable-valued (lazy) entries;
    # a failing entry is reported and dropped rather than aborting the infotext
    for key, value in generation_params.items():
        try:
            if isinstance(value, list):
                generation_params[key] = value[index]
            elif callable(value):
                generation_params[key] = value(**locals())
        except Exception:
            errors.report(f'Error creating infotext for key "{key}"', exc_info=True)
            generation_params[key] = None

    generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])

    negative_prompt_text = f"\nNegative prompt: {negative_prompt}" if negative_prompt else ""

    return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
822
+
823
+
824
def process_images(p: StableDiffusionProcessing) -> Processed:
    """Entry point for a generation job: applies p.override_settings, runs
    process_images_inner, and restores overridden settings afterwards."""
    if p.scripts is not None:
        p.scripts.before_process(p)

    # snapshot the current values of every setting we are about to override,
    # so they can be restored in the finally block
    stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}

    try:
        # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
        # and if after running refiner, the refiner model is not unloaded - webui swaps back to main model here, if model over is present it will be reloaded afterwards
        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
            p.override_settings.pop('sd_model_checkpoint', None)
            sd_models.reload_model_weights()

        for k, v in p.override_settings.items():
            opts.set(k, v, is_api=True, run_callbacks=False)

            # model/VAE overrides require an explicit reload to take effect
            if k == 'sd_model_checkpoint':
                sd_models.reload_model_weights()

            if k == 'sd_vae':
                sd_vae.reload_vae_weights()

        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

        # backwards compatibility, fix sampler and scheduler if invalid
        sd_samplers.fix_p_invalid_sampler_and_scheduler(p)

        with profiling.Profiler():
            res = process_images_inner(p)

    finally:
        # always undo token merging, even if generation failed
        sd_models.apply_token_merging(p.sd_model, 0)

        # restore opts to original state
        if p.override_settings_restore_afterwards:
            for k, v in stored_opts.items():
                setattr(opts, k, v)

                if k == 'sd_vae':
                    sd_vae.reload_vae_weights()

    return res
866
+
867
+
868
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    if isinstance(p.prompt, list):
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None

    devices.torch_gc()

    # resolve '' / None / -1 / string seeds into concrete ints (lists pass through)
    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    # None means "follow the global setting"
    if p.restore_faces is None:
        p.restore_faces = opts.face_restoration

    if p.tiling is None:
        p.tiling = opts.tiling

    if p.refiner_checkpoint not in (None, "", "None", "none"):
        p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
        if p.refiner_checkpoint_info is None:
            raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')

    # some models restrict valid dimensions; let the model adjust them
    if hasattr(shared.sd_model, 'fix_dimensions'):
        p.width, p.height = shared.sd_model.fix_dimensions(p.width, p.height)

    # record model/VAE identity for the infotext
    p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
    p.sd_model_hash = shared.sd_model.sd_model_hash
    p.sd_vae_name = sd_vae.get_loaded_vae_name()
    p.sd_vae_hash = sd_vae.get_loaded_vae_hash()

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    p.fill_fields_from_opts()
    p.setup_prompts()

    # per-image seeds: consecutive seeds unless variation (subseed) strength is used
    if isinstance(seed, list):
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if isinstance(subseed, list):
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []
    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

            # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                sd_vae_approx.model()

            sd_unet.apply_unet()

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            p.iteration = n

            if state.skipped:
                state.skipped = False

            if state.interrupted or state.stopping_generation:
                break

            sd_models.reload_model_weights()  # model can be changed for example by refiner

            # slice this batch's prompts/seeds out of the whole job
            p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
            p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)

            if p.scripts is not None:
                p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            if len(p.prompts) == 0:
                break

            p.parse_extra_network_prompts()

            if not p.disable_extra_networks:
                with devices.autocast():
                    extra_networks.activate(p, p.extra_network_data)

            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            p.setup_conds()

            p.extra_generation_params.update(model_hijack.extra_generation_params)

            # params.txt should be saved after scripts.process_batch, since the
            # infotext could be modified by that callback
            # Example: a wildcard processed by process_batch sets an extra model
            # strength, which is saved as "Model Strength: 1.0" in the infotext
            if n == 0 and not cmd_opts.no_prompt_history:
                with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
                    processed = Processed(p, [])
                    file.write(processed.infotext(p, 0))

            for comment in model_hijack.comments:
                p.comment(comment)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            sd_models.apply_alpha_schedule_override(p.sd_model, p)

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            if p.scripts is not None:
                ps = scripts.PostSampleArgs(samples_ddim)
                p.scripts.post_sample(p, ps)
                samples_ddim = ps.samples

            # hires fix may return already-decoded images (DecodedSamples);
            # otherwise decode the latents through the VAE here
            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
                devices.test_for_nans(samples_ddim, "unet")

                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            # map decoded samples from [-1, 1] to [0, 1]
            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if lowvram.is_enabled(shared.sd_model):
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            state.nextjob()

            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

                # re-slice prompts: postprocess_batch may have changed them
                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]

                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
                x_samples_ddim = batch_params.images

            def infotext(index=0, use_main_prompt=False):
                # closes over this batch's prompts/seeds
                return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)

            save_samples = p.save_samples()

            for i, x_sample in enumerate(x_samples_ddim):
                p.batch_index = i

                # CHW float [0,1] -> HWC uint8
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image

                mask_for_overlay = getattr(p, "mask_for_overlay", None)

                # overlay_image is the original image pasted back over the
                # generated one for inpainting; only used when enabled in opts
                if not shared.opts.overlay_inpaint:
                    overlay_image = None
                elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images):
                    overlay_image = p.overlay_images[i]
                else:
                    overlay_image = None

                if p.scripts is not None:
                    ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
                    p.scripts.postprocess_maskoverlay(p, ppmo)
                    mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if save_samples and opts.save_images_before_color_correction:
                        image_without_cc, _ = apply_overlay(image, p.paste_to, overlay_image)
                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                # If the intention is to show the output from the model
                # that is being composited over the original image,
                # we need to keep the original image around
                # and use it in the composite step.
                image, original_denoised_image = apply_overlay(image, p.paste_to, overlay_image)

                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image_after_composite(p, pp)
                    image = pp.image

                if save_samples:
                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)

                text = infotext(i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

                if mask_for_overlay is not None:
                    if opts.return_mask or opts.save_mask:
                        image_mask = mask_for_overlay.convert('RGB')
                        if save_samples and opts.save_mask:
                            images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
                        if opts.return_mask:
                            output_images.append(image_mask)

                    if opts.return_mask_composite or opts.save_mask_composite:
                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
                        if save_samples and opts.save_mask_composite:
                            images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
                        if opts.return_mask_composite:
                            output_images.append(image_mask_composite)

            del x_samples_ddim

            devices.torch_gc()

        if not infotexts:
            infotexts.append(Processed(p, []).infotext(p, 0))

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                # NOTE(review): `infotext` here is the closure defined inside the
                # last executed batch iteration
                text = infotext(use_main_prompt=True)
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1
            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    if not p.disable_extra_networks and p.extra_network_data:
        extra_networks.deactivate(p, p.extra_network_data)

    devices.torch_gc()

    res = Processed(
        p,
        images_list=output_images,
        seed=p.all_seeds[0],
        info=infotexts[0],
        subseed=p.all_subseeds[0],
        index_of_first_image=index_of_first_image,
        infotexts=infotexts,
    )

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res
1156
+
1157
+
1158
def old_hires_fix_first_pass_dimensions(width, height):
    """Legacy auto-calculation of the hires-fix first-pass size.

    Scales the requested dimensions so their area is roughly 512x512, then
    rounds each side up to the next multiple of 64.
    """
    target_area = 512 * 512
    current_area = width * height
    factor = math.sqrt(target_area / current_area)

    new_width = math.ceil(factor * width / 64) * 64
    new_height = math.ceil(factor * height / 64) * 64

    return new_width, new_height
1168
+
1169
+
1170
@dataclass(repr=False)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    """Processing parameters for txt2img, including the optional hires-fix
    second pass and the custom progressive-growing sampling mode."""

    enable_hr: bool = False  # run the hires-fix second pass
    denoising_strength: float = 0.75  # denoising strength for the hires pass
    firstphase_width: int = 0  # legacy API: explicit first-pass width (0 = unused)
    firstphase_height: int = 0  # legacy API: explicit first-pass height (0 = unused)
    hr_scale: float = 2.0  # upscale factor when hr_resize_x/y are not given
    hr_upscaler: str = None  # upscaler name; None falls back to the default latent mode
    hr_second_pass_steps: int = 0  # 0 means reuse the first-pass step count
    hr_resize_x: int = 0  # explicit hires target width (0 = use hr_scale)
    hr_resize_y: int = 0  # explicit hires target height (0 = use hr_scale)
    hr_checkpoint_name: str = None  # different checkpoint for the hires pass, if any
    hr_sampler_name: str = None  # different sampler for the hires pass, if any
    hr_scheduler: str = None  # different scheduler for the hires pass, if any
    hr_prompt: str = ''  # separate prompt for the hires pass ('' = same as base)
    hr_negative_prompt: str = ''  # separate negative prompt for the hires pass
    force_task_id: str = None

    # class-level caches for hires conditioning, shared across instances
    cached_hr_uc = [None, None]
    cached_hr_c = [None, None]

    # internal state, computed during init/sampling - not constructor arguments
    hr_checkpoint_info: dict = field(default=None, init=False)
    hr_upscale_to_x: int = field(default=0, init=False)
    hr_upscale_to_y: int = field(default=0, init=False)
    truncate_x: int = field(default=0, init=False)  # latent-space crop after upscale (in opt_f units)
    truncate_y: int = field(default=0, init=False)
    applied_old_hires_behavior_to: tuple = field(default=None, init=False)
    latent_scale_mode: dict = field(default=None, init=False)
    hr_c: tuple | None = field(default=None, init=False)
    hr_uc: tuple | None = field(default=None, init=False)
    all_hr_prompts: list = field(default=None, init=False)
    all_hr_negative_prompts: list = field(default=None, init=False)
    hr_prompts: list = field(default=None, init=False)
    hr_negative_prompts: list = field(default=None, init=False)
    hr_extra_network_data: list = field(default=None, init=False)
    # progressive growing: sample at increasing resolutions (custom feature)
    enable_progressive_growing: bool = field(default=False, init=False)
    progressive_growing_min_scale: float = field(default=0.25, init=False)  # starting fraction of target size
    progressive_growing_max_scale: float = field(default=1.0, init=False)  # final fraction of target size
    progressive_growing_steps: int = field(default=4, init=False)  # number of resolution stages
    progressive_growing_refinement: bool = field(default=True, init=False)  # run img2img refinement at each stage
1211
+ def __post_init__(self):
1212
+ super().__post_init__()
1213
+
1214
+ self.enable_progressive_growing = getattr(self, 'enable_progressive_growing', False)
1215
+ self.progressive_growing_min_scale = getattr(self, 'progressive_growing_min_scale', 0.25)
1216
+ self.progressive_growing_max_scale = getattr(self, 'progressive_growing_max_scale', 1.0)
1217
+ self.progressive_growing_steps = getattr(self, 'progressive_growing_steps', 4)
1218
+ self.progressive_growing_refinement = getattr(self, 'progressive_growing_refinement', True)
1219
+
1220
+ def __post_init__(self):
1221
+ super().__post_init__()
1222
+
1223
+ if self.firstphase_width != 0 or self.firstphase_height != 0:
1224
+ self.hr_upscale_to_x = self.width
1225
+ self.hr_upscale_to_y = self.height
1226
+ self.width = self.firstphase_width
1227
+ self.height = self.firstphase_height
1228
+
1229
+ self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
1230
+ self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
1231
+
1232
    def calculate_target_resolution(self):
        """Compute hr_upscale_to_x/y (and truncate_x/y) for the hires pass.

        Either scales width/height by hr_scale, or fits them to the explicit
        hr_resize_x/y target while preserving aspect ratio; records the choice
        in extra_generation_params.
        """
        # old behavior: the user-facing width/height are the FINAL size, and the
        # first pass is auto-sized to ~512x512 area
        if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
            self.hr_resize_x = self.width
            self.hr_resize_y = self.height
            self.hr_upscale_to_x = self.width
            self.hr_upscale_to_y = self.height

            self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
            # remember so repeated calls with the same size don't shrink again
            self.applied_old_hires_behavior_to = (self.width, self.height)

        if self.hr_resize_x == 0 and self.hr_resize_y == 0:
            # no explicit target: scale both sides by hr_scale
            self.extra_generation_params["Hires upscale"] = self.hr_scale
            self.hr_upscale_to_x = int(self.width * self.hr_scale)
            self.hr_upscale_to_y = int(self.height * self.hr_scale)
        else:
            self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"

            if self.hr_resize_y == 0:
                # only width given: derive height from the aspect ratio
                self.hr_upscale_to_x = self.hr_resize_x
                self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
            elif self.hr_resize_x == 0:
                # only height given: derive width from the aspect ratio
                self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                self.hr_upscale_to_y = self.hr_resize_y
            else:
                # both given: upscale to cover the target, then crop the excess
                target_w = self.hr_resize_x
                target_h = self.hr_resize_y
                src_ratio = self.width / self.height
                dst_ratio = self.hr_resize_x / self.hr_resize_y

                if src_ratio < dst_ratio:
                    self.hr_upscale_to_x = self.hr_resize_x
                    self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                else:
                    self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                    self.hr_upscale_to_y = self.hr_resize_y

                # excess to crop off, expressed in latent units (opt_f pixels)
                self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
1270
+
1271
    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare hires-fix state before sampling: resolve the hires
        checkpoint/sampler/upscaler, compute target resolution, and register
        infotext entries. No-op unless enable_hr is set."""
        if self.enable_hr:
            self.extra_generation_params["Denoising strength"] = self.denoising_strength

            if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)

                if self.hr_checkpoint_info is None:
                    raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')

                self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title

            if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
                self.extra_generation_params["Hires sampler"] = self.hr_sampler_name

            # infotext callbacks: only record hires prompts when they differ
            # from the base prompts (see create_infotext's callable handling)
            def get_hr_prompt(p, index, prompt_text, **kwargs):
                hr_prompt = p.all_hr_prompts[index]
                return hr_prompt if hr_prompt != prompt_text else None

            def get_hr_negative_prompt(p, index, negative_prompt, **kwargs):
                hr_negative_prompt = p.all_hr_negative_prompts[index]
                return hr_negative_prompt if hr_negative_prompt != negative_prompt else None

            self.extra_generation_params["Hires prompt"] = get_hr_prompt
            self.extra_generation_params["Hires negative prompt"] = get_hr_negative_prompt

            self.extra_generation_params["Hires schedule type"] = None  # to be set in sd_samplers_kdiffusion.py

            if self.hr_scheduler is None:
                self.hr_scheduler = self.scheduler

            # None means the upscaler is an image-space (non-latent) one
            self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
            if self.enable_hr and self.latent_scale_mode is None:
                if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
                    raise Exception(f"could not find upscaler named {self.hr_upscaler}")

            self.calculate_target_resolution()

            # account for the second pass in the progress/job counters
            if not state.processing_has_refined_job_count:
                if state.job_count == -1:
                    state.job_count = self.n_iter

                if getattr(self, 'txt2img_upscale', False):
                    total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
                else:
                    total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count

                shared.total_tqdm.updateTotal(total_steps)
                state.job_count = state.job_count * 2
                state.processing_has_refined_job_count = True

            if self.hr_second_pass_steps:
                self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps

            if self.hr_upscaler is not None:
                self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
1326
    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Run txt2img sampling: either progressive-growing mode, or a normal
        first pass optionally followed by the hires-fix second pass.

        Returns latents, or already-decoded samples when the hires pass
        decodes them itself.
        """
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        # custom mode: sample at increasing resolutions instead
        if self.enable_progressive_growing:
            return self.sample_progressive(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts)

        if self.firstpass_image is not None and self.enable_hr:
            # here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix

            if self.latent_scale_mode is None:
                # image-space upscaler: keep decoded pixels in [-1, 1]
                image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
                image = np.moveaxis(image, 2, 0)

                samples = None
                decoded_samples = torch.asarray(np.expand_dims(image, 0))

            else:
                # latent upscaler: encode the image into latents via the VAE
                image = np.array(self.firstpass_image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                image = torch.from_numpy(np.expand_dims(image, axis=0))
                image = image.to(shared.device, dtype=devices.dtype_vae)

                if opts.sd_vae_encode_method != 'Full':
                    self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

                samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
                decoded_samples = None
                devices.torch_gc()

        else:
            # here we generate an image normally

            x = self.rng.next()
            if self.scripts is not None:
                self.scripts.process_before_every_sampling(
                    p=self,
                    x=x,
                    noise=x,
                    c=conditioning,
                    uc=unconditional_conditioning
                )

            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
            del x

            if not self.enable_hr:
                return samples

            devices.torch_gc()

            # image-space upscalers need decoded pixels; latent ones work on
            # the latents directly
            if self.latent_scale_mode is None:
                decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
            else:
                decoded_samples = None

        # swap to the hires checkpoint (no-op when hr_checkpoint_info is None)
        with sd_models.SkipWritingToConfig():
            sd_models.reload_model_weights(info=self.hr_checkpoint_info)

        return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
1385
+
1386
    def sample_progressive(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Progressive-growing sampling: start at a fraction of the target
        resolution, then repeatedly upscale the latents and (optionally)
        refine them with img2img at each stage.

        Returns the final latent batch.
        """
        is_sdxl = getattr(self.sd_model, 'is_sdxl', False)

        # SDXL degrades badly below ~512px, so clamp the starting scale
        if is_sdxl:
            min_scale = max(0.5, self.progressive_growing_min_scale)
        else:
            min_scale = self.progressive_growing_min_scale

        # evenly spaced resolution fractions from min_scale to max_scale
        resolution_steps = np.linspace(min_scale, self.progressive_growing_max_scale, self.progressive_growing_steps)

        initial_width = max(512 if is_sdxl else 64, int(self.width * resolution_steps[0]))
        initial_height = max(512 if is_sdxl else 64, int(self.height * resolution_steps[0]))

        # first stage: plain txt2img sampling at the smallest resolution
        x = create_random_tensors((opt_C, initial_height // opt_f, initial_width // opt_f), seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))

        for i in range(1, len(resolution_steps)):
            target_width = int(self.width * resolution_steps[i])
            target_height = int(self.height * resolution_steps[i])

            if is_sdxl:
                target_width = max(512, min(1536, target_width))
                target_height = max(512, min(1536, target_height))

            # upscale the latents to the next stage's resolution
            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode='bicubic', align_corners=False)

            if self.progressive_growing_refinement:
                # split the step budget evenly across stages
                steps_for_refinement = self.steps // len(resolution_steps)
                # samples.shape[1:] is the per-sample latent shape (no batch dim)
                noise = create_random_tensors(samples.shape[1:], seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
                # decode to image space to rebuild the img2img conditioning
                # NOTE(review): this full decode per stage is expensive; presumably
                # intentional to keep conditioning accurate - confirm
                decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
                decoded_samples = torch.stack(decoded_samples).float()
                decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
                self.image_conditioning = self.img2img_image_conditioning(decoded_samples * 2 - 1, samples)

                samples = self.sampler.sample_img2img(
                    self,
                    samples,
                    noise,
                    conditioning,
                    unconditional_conditioning,
                    steps=steps_for_refinement,
                    image_conditioning=self.image_conditioning
                )

        return samples
1431
+
1432
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
    """Run the hires-fix second pass: upscale the firstpass result to
    (hr_upscale_to_x, hr_upscale_to_y) and re-sample it with img2img.

    `samples` is the firstpass latent batch; `decoded_samples` is its decoded
    image-space counterpart (None when a latent upscale mode is in use).
    Returns decoded image-space samples on the CPU.
    """
    if shared.state.interrupted:
        return samples

    self.is_hr_pass = True
    target_width = self.hr_upscale_to_x
    target_height = self.hr_upscale_to_y

    def save_intermediate(image, index):
        """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""

        if not self.save_samples() or not opts.save_images_before_highres_fix:
            return

        if not isinstance(image, Image.Image):
            image = sd_samplers.sample_to_image(image, index, approximation=0)

        info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
        images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")

    img2img_sampler_name = self.hr_sampler_name or self.sampler_name

    self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

    if self.latent_scale_mode is not None:
        # Latent-space upscale: interpolate the latent batch directly.
        for i in range(samples.shape[0]):
            save_intermediate(samples, i)

        samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])

        # Avoid making the inpainting conditioning unless necessary as
        # this does need some extra compute to decode / encode the image again.
        if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
            image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
        else:
            image_conditioning = self.txt2img_image_conditioning(samples)
    else:
        # Image-space upscale: map decoded samples to [0, 1], upscale each
        # image with the chosen upscaler, then re-encode through the VAE.
        lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

        batch_images = []
        for i, x_sample in enumerate(lowres_samples):
            # CHW float [0,1] -> HWC uint8 for PIL.
            x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
            x_sample = x_sample.astype(np.uint8)
            image = Image.fromarray(x_sample)

            save_intermediate(image, i)

            image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)
            batch_images.append(image)

        decoded_samples = torch.from_numpy(np.array(batch_images))
        decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)

        if opts.sd_vae_encode_method != 'Full':
            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
        samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))

        image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

    shared.state.nextjob()

    # Crop the latent to the exact target aspect ratio; truncate_x/truncate_y
    # are presumably computed earlier in init — TODO confirm against caller.
    samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]

    self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)
    noise = self.rng.next()

    # GC now before running the next img2img to prevent running out of memory
    devices.torch_gc()

    if not self.disable_extra_networks:
        with devices.autocast():
            extra_networks.activate(self, self.hr_extra_network_data)

    with devices.autocast():
        self.calculate_hr_conds()

    sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))

    if self.scripts is not None:
        self.scripts.before_hr(self)
        self.scripts.process_before_every_sampling(
            p=self,
            x=samples,
            noise=noise,
            c=self.hr_c,
            uc=self.hr_uc,
        )

    samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

    # Restore the firstpass token-merging ratio after the hires sampling.
    sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

    self.sampler = None
    devices.torch_gc()

    decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

    self.is_hr_pass = False
    return decoded_samples
1533
+
1534
def close(self):
    """Release per-job resources and drop hires conditioning caches."""
    super().close()
    self.hr_c = None
    self.hr_uc = None
    if opts.persistent_cond_cache:
        return
    # The hires cond caches live on the class and are shared across jobs;
    # wipe them unless the user opted into persistent caching.
    StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
    StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
1541
+
1542
def setup_prompts(self):
    """Prepare per-image hires prompts: default empty hires prompts to the
    firstpass ones, expand to one entry per generated image, apply styles."""
    super().setup_prompts()

    if not self.enable_hr:
        return

    if self.hr_prompt == '':
        self.hr_prompt = self.prompt
    if self.hr_negative_prompt == '':
        self.hr_negative_prompt = self.negative_prompt

    total = self.batch_size * self.n_iter
    self.all_hr_prompts = self.hr_prompt if isinstance(self.hr_prompt, list) else total * [self.hr_prompt]
    self.all_hr_negative_prompts = self.hr_negative_prompt if isinstance(self.hr_negative_prompt, list) else total * [self.hr_negative_prompt]

    self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(p, self.styles) for p in self.all_hr_prompts]
    self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(p, self.styles) for p in self.all_hr_negative_prompts]
1566
+
1567
def calculate_hr_conds(self):
    """Compute and cache conditioning for the hires pass; no-op when already computed."""
    if self.hr_c is not None:
        return

    hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
    hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)

    sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
    steps = self.hr_second_pass_steps or self.steps
    total_steps = sampler_config.total_steps(steps) if sampler_config else steps

    # Two caches are passed per cond: the hires cache and the firstpass cache —
    # presumably so the hires pass can reuse firstpass conds when the prompts
    # match; verify against get_conds_with_caching.
    self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
    self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
1580
+
1581
def setup_conds(self):
    """Compute firstpass conditioning and, when possible, precompute hires conditioning."""
    if self.is_hr_pass:
        # if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model
        self.hr_c = None
        self.calculate_hr_conds()
        return

    super().setup_conds()

    self.hr_uc = None
    self.hr_c = None

    # Hires conds can only be precomputed here when the hires pass uses the
    # same checkpoint as the first pass.
    if self.enable_hr and self.hr_checkpoint_info is None:
        if shared.opts.hires_fix_use_firstpass_conds:
            self.calculate_hr_conds()

        elif lowvram.is_enabled(shared.sd_model) and shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint():  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
            with devices.autocast():
                extra_networks.activate(self, self.hr_extra_network_data)

            self.calculate_hr_conds()

            # Re-activate the firstpass extra networks afterwards, since the
            # hires set was activated above just to compute hires conds.
            with devices.autocast():
                extra_networks.activate(self, self.extra_network_data)
1606
def get_conds(self):
    """Return (cond, uncond) — the hires pair while the hires pass is running."""
    if not self.is_hr_pass:
        return super().get_conds()
    return self.hr_c, self.hr_uc
1611
+
1612
def parse_extra_network_prompts(self):
    """Parse extra-network tags out of the firstpass prompts and, when hires
    is enabled, out of this iteration's slice of the hires prompts too."""
    res = super().parse_extra_network_prompts()

    if self.enable_hr:
        start = self.iteration * self.batch_size
        end = start + self.batch_size
        self.hr_prompts = self.all_hr_prompts[start:end]
        self.hr_negative_prompts = self.all_hr_negative_prompts[start:end]

        self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)

    return res
1622
+
1623
+
1624
@dataclass(repr=False)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    """Processing parameters and pipeline for img2img and inpainting:
    starts denoising from user-supplied images instead of pure noise."""

    init_images: list = None           # source PIL images to start from
    resize_mode: int = 0               # how init images are fitted to (width, height); 3 = resize in latent space
    denoising_strength: float = 0.75   # how strongly the original image is repainted
    image_cfg_scale: float = None      # only used by "edit" (cond_stage_key == "edit") models
    mask: Any = None                   # inpaint mask from the caller; moved to image_mask in __post_init__
    mask_blur_x: int = 4               # horizontal gaussian blur radius for the mask
    mask_blur_y: int = 4               # vertical gaussian blur radius for the mask
    mask_blur: int = None              # convenience accessor for both axes; see the property below
    mask_round: bool = True            # whether the latent mask is rounded to 0/1
    inpainting_fill: int = 0           # masked content: 0=fill, 1=original, 2=latent noise, 3=latent nothing
    inpaint_full_res: bool = True      # inpaint only the masked region, upscaled to full resolution
    inpaint_full_res_padding: int = 0  # padding (px) around the masked region when cropping
    inpainting_mask_invert: int = 0    # nonzero = inpaint the NOT-masked area
    initial_noise_multiplier: float = None
    latent_mask: Image = None          # optional separate mask to use in latent space
    force_task_id: str = None

    image_mask: Any = field(default=None, init=False)

    nmask: torch.Tensor = field(default=None, init=False)                # latent-space mask (1 where repainted)
    image_conditioning: torch.Tensor = field(default=None, init=False)
    init_img_hash: str = field(default=None, init=False)
    mask_for_overlay: Image = field(default=None, init=False)            # mask used when pasting results over the original
    init_latent: torch.Tensor = field(default=None, init=False)

    def __post_init__(self):
        super().__post_init__()

        # The caller-facing `mask` is stored as image_mask; `self.mask` is
        # reused later for the latent-space mask tensor built in init().
        self.image_mask = self.mask
        self.mask = None
        self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier

    @property
    def mask_blur(self):
        # NOTE(review): this property shadows the dataclass field of the same
        # name; the generated __init__ assigns through the setter below.
        # Returns None when the two axes disagree.
        if self.mask_blur_x == self.mask_blur_y:
            return self.mask_blur_x
        return None

    @mask_blur.setter
    def mask_blur(self, value):
        # A non-int (e.g. the None default) leaves the per-axis values untouched.
        if isinstance(value, int):
            self.mask_blur_x = value
            self.mask_blur_y = value

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare the mask, overlay images, init latent and image conditioning before sampling."""
        self.extra_generation_params["Denoising strength"] = self.denoising_strength

        self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None

        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
        crop_region = None

        image_mask = self.image_mask

        if image_mask is not None:
            # image_mask is passed in as RGBA by Gradio to support alpha masks,
            # but we still want to support binary masks.
            image_mask = create_binary_mask(image_mask, round=self.mask_round)

            if self.inpainting_mask_invert:
                image_mask = ImageOps.invert(image_mask)
                self.extra_generation_params["Mask mode"] = "Inpaint not masked"

            # Blur each axis independently with a 1-D gaussian kernel.
            if self.mask_blur_x > 0:
                np_mask = np.array(image_mask)
                kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
                np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
                image_mask = Image.fromarray(np_mask)

            if self.mask_blur_y > 0:
                np_mask = np.array(image_mask)
                kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
                np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
                image_mask = Image.fromarray(np_mask)

            if self.mask_blur_x > 0 or self.mask_blur_y > 0:
                self.extra_generation_params["Mask blur"] = self.mask_blur

            if self.inpaint_full_res:
                # "Only masked" mode: crop to the masked region (plus padding),
                # process at full resolution, paste back later via paste_to.
                self.mask_for_overlay = image_mask
                mask = image_mask.convert('L')
                crop_region = masking.get_crop_region_v2(mask, self.inpaint_full_res_padding)
                if crop_region:
                    crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                    x1, y1, x2, y2 = crop_region
                    mask = mask.crop(crop_region)
                    image_mask = images.resize_image(2, mask, self.width, self.height)
                    self.paste_to = (x1, y1, x2-x1, y2-y1)
                    self.extra_generation_params["Inpaint area"] = "Only masked"
                    self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding
                else:
                    # Blank mask: fall back to plain img2img.
                    crop_region = None
                    image_mask = None
                    self.mask_for_overlay = None
                    self.inpaint_full_res = False
                    massage = 'Unable to perform "Inpaint Only mask" because mask is blank, switch to img2img mode.'
                    model_hijack.comments.append(massage)
                    logging.info(massage)
            else:
                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
                np_mask = np.array(image_mask)
                # Doubling the mask values hardens soft (blurred) edges for the overlay.
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

        self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:

            # Save init image
            if opts.save_init_img:
                self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)

            image = images.flatten(img, opts.img2img_background_color)

            if crop_region is None and self.resize_mode != 3:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if image_mask is not None:
                if self.mask_for_overlay.size != (image.width, image.height):
                    self.mask_for_overlay = images.resize_image(self.resize_mode, self.mask_for_overlay, image.width, image.height)
                # Keep the original pixels outside the mask for compositing the result.
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            # crop_region is not None if we are doing inpaint full res
            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

                if self.inpainting_fill == 0:
                    self.extra_generation_params["Masked content"] = 'fill'

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            # HWC uint8 -> CHW float [0, 1].
            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            # One init image: replicate it across the whole batch.
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = image.to(shared.device, dtype=devices.dtype_vae)

        if opts.sd_vae_encode_method != 'Full':
            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

        self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
        devices.torch_gc()

        if self.resize_mode == 3:
            # Latent-space resize mode: scale the encoded latent instead of the image.
            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

        if image_mask is not None:
            # Build the latent-space mask: downscale to latent resolution,
            # optionally binarize, and broadcast over latent channels.
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            if self.mask_round:
                latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (self.init_latent.shape[1], 1, 1))

            # self.mask keeps original content; self.nmask marks what gets repainted.
            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(devices.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(devices.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
                self.extra_generation_params["Masked content"] = 'latent noise'

            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask
                self.extra_generation_params["Masked content"] = 'latent nothing'

        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Run img2img sampling from init_latent and blend masked regions back in."""
        x = self.rng.next()

        if self.initial_noise_multiplier != 1.0:
            self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
            x *= self.initial_noise_multiplier

        if self.scripts is not None:
            self.scripts.process_before_every_sampling(
                p=self,
                x=self.init_latent,
                noise=x,
                c=conditioning,
                uc=unconditional_conditioning
            )
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            # Restore unmasked regions from the init latent; scripts may
            # override the blend via on_mask_blend.
            blended_samples = samples * self.nmask + self.init_latent * self.mask

            if self.scripts is not None:
                mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
                self.scripts.on_mask_blend(self, mba)
                blended_samples = mba.blended_latent

            samples = blended_samples

        del x
        devices.torch_gc()

        return samples

    def get_token_merging_ratio(self, for_hr=False):
        # Precedence: explicit per-job ratio > override-settings global > img2img-specific option > global option.
        return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio
sec/sampling.py ADDED
@@ -0,0 +1,726 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ from scipy import integrate
4
+ import torch
5
+ from torch import nn
6
+ from torchdiffeq import odeint
7
+ import torchsde
8
+ from tqdm.auto import trange, tqdm
9
+
10
+ from . import utils
11
+
12
+
13
def append_zero(x):
    """Return *x* with a single trailing zero element appended."""
    zero = x.new_zeros([1])
    return torch.cat((x, zero))
15
+
16
+
17
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Constructs the noise schedule of Karras et al. (2022)."""
    inv_rho_min = sigma_min ** (1 / rho)
    inv_rho_max = sigma_max ** (1 / rho)
    ramp = torch.linspace(0, 1, n)
    # interpolate in sigma^(1/rho) space, then map back
    sigmas = (inv_rho_max + ramp * (inv_rho_min - inv_rho_max)) ** rho
    # terminate the schedule with sigma = 0
    sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
    return sigmas.to(device)
24
+
25
+
26
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Constructs an exponential noise schedule."""
    # equally spaced in log-sigma, from max down to min
    log_sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device)
    sigmas = log_sigmas.exp()
    return torch.cat([sigmas, sigmas.new_zeros([1])])
30
+
31
+
32
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
    """Constructs a polynomial-in-log-sigma noise schedule."""
    log_min = math.log(sigma_min)
    log_max = math.log(sigma_max)
    ramp = torch.linspace(1, 0, n, device=device) ** rho
    sigmas = torch.exp(log_min + ramp * (log_max - log_min))
    return torch.cat([sigmas, sigmas.new_zeros([1])])
37
+
38
+
39
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Constructs a continuous VP noise schedule."""
    t = torch.linspace(1, eps_s, n, device=device)
    # sigma(t) = sqrt(e^{beta_d t^2 / 2 + beta_min t} - 1)
    log_term = beta_d * t ** 2 / 2 + beta_min * t
    sigmas = (log_term.exp() - 1).sqrt()
    return torch.cat([sigmas, sigmas.new_zeros([1])])
44
+
45
+
46
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    sigma_b = utils.append_dims(sigma, x.ndim)
    return (x - denoised) / sigma_b
49
+
50
+
51
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    if not eta:
        # eta = 0 degenerates to a deterministic step with no added noise
        return sigma_to, 0.
    variance_ratio = (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2
    sigma_up = min(sigma_to, eta * (sigma_to ** 2 * variance_ratio) ** 0.5)
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
59
+
60
+
61
def default_noise_sampler(x):
    """Return a noise sampler producing unit Gaussian noise shaped like *x*;
    the sigma arguments are accepted for interface compatibility and ignored."""
    def sampler(sigma, sigma_next):
        return torch.randn_like(x)
    return sampler
63
+
64
+
65
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        # Normalize the interval so t0 < t1, remembering the direction in self.sign.
        t0, t1, self.sign = self.sort(t0, t1)
        # NOTE(review): w0 is also left inside kwargs when supplied, and kwargs
        # is forwarded below — relies on callers never passing w0; confirm.
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        self.batched = True
        try:
            # A sequence of seeds means one tree per batch item; len() raises
            # TypeError for a scalar seed, which selects the unbatched path.
            # A wrong-length sequence raises AssertionError and propagates.
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            seed = [seed]
            self.batched = False
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        # Returns (low, high, direction).
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        # Flip the increment when the query direction disagrees with construction order.
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]
90
+
91
+
92
class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
        # Divide by sqrt(|t1 - t0|) to normalize the Brownian increment to unit variance.
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
115
+
116
+
117
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    `model(x, sigma_batch, **extra_args)` must return the denoised prediction.
    When s_churn > 0, noise is temporarily re-injected ("churn") on steps whose
    sigma lies within [s_tmin, s_tmax]. Returns the final sample.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Draw churn noise only when it is actually applied; the original
            # allocated a randn_like every step even with s_churn == 0.
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        # Euler method
        x = x + d * dt
    return x
136
+
137
+
138
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps."""
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # deterministic Euler step down to sigma_down ...
        d = to_d(x, sigma, denoised)
        x = x + d * (sigma_down - sigma)
        # ... then re-inject sigma_up worth of fresh noise (ancestral step)
        if sigma_next > 0:
            x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up
    return x
156
+
157
+
158
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    Second-order sampler: averages the derivative at the current sigma and at
    the Euler-predicted next point, falling back to plain Euler on the final
    step to sigma = 0. Costs two model evaluations per non-final step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Draw churn noise only when it is used; avoids a wasted randn_like
            # allocation per step on the common no-churn (s_churn == 0) path.
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method
            x = x + d * dt
        else:
            # Heun's method: correct with the derivative at the predicted point
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    return x
185
+
186
+
187
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).

    Takes a second model evaluation at the log-space midpoint sigma of each
    step; falls back to plain Euler on the final step to sigma = 0.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Draw churn noise only when it is used; avoids a wasted randn_like
            # allocation per step on the common no-churn (s_churn == 0) path.
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            dt = sigmas[i + 1] - sigma_hat
            x = x + d * dt
        else:
            # DPM-Solver-2: midpoint rule in log-sigma space
            sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
            dt_1 = sigma_mid - sigma_hat
            dt_2 = sigmas[i + 1] - sigma_hat
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
    return x
216
+
217
+
218
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver second-order steps."""
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma = sigmas[i]
        denoised = model(x, sigma * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigma, denoised)
        if sigma_down == 0:
            # final step: plain Euler down to sigma = 0
            x = x + d * (sigma_down - sigma)
        else:
            # DPM-Solver-2: second evaluation at the log-space midpoint
            sigma_mid = sigma.log().lerp(sigma_down.log(), 0.5).exp()
            x_mid = x + d * (sigma_mid - sigma)
            denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)
            d_mid = to_d(x_mid, sigma_mid, denoised_mid)
            x = x + d_mid * (sigma_down - sigma)
        # ancestral noise injection
        x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
245
+
246
+
247
def linear_multistep_coeff(order, t, i, j):
    """Coefficient of the j-th stored derivative for an order-`order` linear
    multistep (Adams-Bashforth-style) update from t[i] to t[i + 1].

    The coefficient is the integral over [t[i], t[i + 1]] of the Lagrange basis
    polynomial through the nodes t[i], t[i - 1], ..., t[i - order + 1] that is
    1 at t[i - j] and 0 at the other nodes.
    """
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')

    def basis(tau):
        # Lagrange basis polynomial evaluated at tau, skipping the j-th node.
        return math.prod(
            (tau - t[i - k]) / (t[i - j] - t[i - k])
            for k in range(order)
            if k != j
        )

    return integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)[0]
258
+
259
+
260
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Linear multistep sampler: combines up to `order` recent derivatives with
    integrated Lagrange coefficients for a high-order explicit step."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Coefficients are computed by scipy quadrature on CPU floats.
    sigmas_cpu = sigmas.detach().cpu().numpy()
    ds = []  # history of the most recent derivatives, oldest first
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)
        ds.append(d)
        if len(ds) > order:
            ds.pop(0)  # keep only the last `order` derivatives
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # Ramp the order up over the first few steps while history accumulates.
        cur_order = min(i + 1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
        # coeffs[0] pairs with the newest derivative, hence reversed(ds).
        x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
    return x
278
+
279
+
280
@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
    """Estimate the log-likelihood of `x` under the model's probability-flow ODE.

    Integrates the ODE from sigma_min to sigma_max while accumulating the
    divergence term via a Hutchinson trace estimate (single Rademacher probe).
    Returns (per-sample log-likelihood, {'fevals': model evaluation count}).
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    v = torch.randint_like(x, 2) * 2 - 1  # Rademacher probe vector in {-1, +1}
    fevals = 0
    def ode_fn(sigma, x):
        nonlocal fevals
        with torch.enable_grad():
            x = x[0].detach().requires_grad_()
            denoised = model(x, sigma * s_in, **extra_args)
            d = to_d(x, sigma, denoised)
            fevals += 1
            grad = torch.autograd.grad((d * v).sum(), x)[0]
            # v^T J v: unbiased estimate of the divergence of d (trace of Jacobian).
            d_ll = (v * grad).flatten(1).sum(1)
        return d.detach(), d_ll
    x_min = x, x.new_zeros([x.shape[0]])
    t = x.new_tensor([sigma_min, sigma_max])
    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
    latent, delta_ll = sol[0][-1], sol[1][-1]
    # Prior is an isotropic Gaussian at the terminal noise level sigma_max.
    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
    return ll_prior + delta_ll, {'fevals': fevals}
302
+
303
+
304
class PIDStepSizeController:
    """A PID controller for ODE adaptive step size control.

    Tracks the last three inverse error estimates and scales the step size
    `h` by a PID-weighted product of them; a step is accepted when the
    proposed growth factor is at least `accept_safety`.
    """

    def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
        self.h = h
        # PID gains folded into three exponents on the error history.
        self.b1 = (pcoeff + icoeff + dcoeff) / order
        self.b2 = -(pcoeff + 2 * dcoeff) / order
        self.b3 = dcoeff / order
        self.accept_safety = accept_safety
        self.eps = eps
        self.errs = []

    def limiter(self, x):
        # Smoothly clamp the growth factor around 1 to avoid wild step changes.
        return 1 + math.atan(x - 1)

    def propose_step(self, error):
        inv_error = 1 / (float(error) + self.eps)
        if not self.errs:
            # Seed the whole history with the first observation.
            self.errs = [inv_error] * 3
        else:
            self.errs[0] = inv_error
        raw = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
        factor = self.limiter(raw)
        accept = factor >= self.accept_safety
        if accept:
            # Shift the history only on accepted steps.
            self.errs[2], self.errs[1] = self.errs[1], self.errs[0]
        self.h *= factor
        return accept
331
+
332
+
333
class DPMSolver(nn.Module):
    """DPM-Solver. See https://arxiv.org/abs/2206.00927.

    Works in the time variable t = -log(sigma). Provides fixed-step
    (`dpm_solver_fast`) and adaptive-step (`dpm_solver_adaptive`) drivers built
    from first/second/third-order exponential-integrator updates.
    """

    def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
        super().__init__()
        self.model = model
        self.extra_args = {} if extra_args is None else extra_args
        self.eps_callback = eps_callback  # invoked after every model evaluation (e.g. progress bar)
        self.info_callback = info_callback  # invoked once per solver step with diagnostics

    def t(self, sigma):
        # Solver time variable: t = -log(sigma).
        return -sigma.log()

    def sigma(self, t):
        # Inverse of t(): sigma = exp(-t).
        return t.neg().exp()

    def eps(self, eps_cache, key, x, t, *args, **kwargs):
        """Noise prediction at (x, t), memoized in `eps_cache` under `key`."""
        if key in eps_cache:
            return eps_cache[key], eps_cache
        sigma = self.sigma(t) * x.new_ones([x.shape[0]])
        # eps = (x - denoised) / sigma, i.e. the model's noise estimate.
        eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
        if self.eps_callback is not None:
            self.eps_callback()
        return eps, {key: eps, **eps_cache}

    def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
        """First-order (exponential Euler) update from t to t_next."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        x_1 = x - self.sigma(t_next) * h.expm1() * eps
        return x_1, eps_cache

    def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
        """Second-order update with one intermediate evaluation at t + r1 * h."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
        return x_2, eps_cache

    def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
        """Third-order update with two intermediate evaluations."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        s2 = t + r2 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
        eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
        x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
        return x_3, eps_cache

    def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
        """Fixed-step driver: spends `nfe` model evaluations, using third-order
        steps where possible and lower orders for the remainder."""
        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
        if not t_end > t_start and eta:
            raise ValueError('eta must be 0 for reverse sampling')

        m = math.floor(nfe / 3) + 1
        ts = torch.linspace(t_start, t_end, m + 1, device=x.device)

        # Distribute the evaluation budget over step orders.
        if nfe % 3 == 0:
            orders = [3] * (m - 2) + [2, 1]
        else:
            orders = [3] * (m - 1) + [nfe % 3]

        for i in range(len(orders)):
            eps_cache = {}
            t, t_next = ts[i], ts[i + 1]
            if eta:
                # Ancestral variant: shorten the deterministic step and re-add noise.
                sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                t_next_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
            else:
                t_next_, su = t_next, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
            denoised = x - self.sigma(t) * eps
            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})

            if orders[i] == 1:
                x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
            elif orders[i] == 2:
                x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
            else:
                x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)

            x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))

        return x

    def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
        """Adaptive-step driver: compares an order-(n-1) and an order-n update
        per step and lets a PID controller adjust the step size from the gap."""
        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
        if order not in {2, 3}:
            raise ValueError('order should be 2 or 3')
        forward = t_end > t_start
        if not forward and eta:
            raise ValueError('eta must be 0 for reverse sampling')
        h_init = abs(h_init) * (1 if forward else -1)
        atol = torch.tensor(atol)
        rtol = torch.tensor(rtol)
        s = t_start
        x_prev = x
        accept = True
        pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
        info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

        while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
            eps_cache = {}
            # Clamp the proposed step so we never overshoot t_end.
            t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
            if eta:
                sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
                t_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
            else:
                t_, su = t, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
            denoised = x - self.sigma(s) * eps

            # Error estimate from the difference of two solution orders.
            if order == 2:
                x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
            else:
                x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
            delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
            error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
            accept = pid.propose_step(error)
            if accept:
                x_prev = x_low
                x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
                s = t
                info['n_accept'] += 1
            else:
                info['n_reject'] += 1
            info['nfe'] += order
            info['steps'] += 1

            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})

        return x, info
479
+
480
+
481
@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.

    Thin wrapper around DPMSolver.dpm_solver_fast: converts the sigma range to
    solver time and adapts the callback/progress-bar plumbing.
    """
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(total=n, disable=disable) as pbar:
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            # Translate solver-time diagnostics back into sigma space for callers.
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)
491
+
492
+
493
@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.

    Thin wrapper around DPMSolver.dpm_solver_adaptive; set return_info=True to
    also receive the solver's step/accept/reject statistics.
    """
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(disable=disable) as pbar:
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            # Translate solver-time diagnostics back into sigma space for callers.
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
    if return_info:
        return x, info
    return x
506
+
507
+
508
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    # DPM-Solver++ works in t = -log(sigma) space.
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # Split the step into a deterministic part (to sigma_down) and noise (sigma_up).
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigma_down == 0:
            # Euler method (t_fn is undefined at sigma 0)
            d = to_d(x, sigmas[i], denoised)
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++(2S): intermediate evaluation at the log-space midpoint.
            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
            r = 1 / 2
            h = t_next - t
            s = t + r * h
            x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
        # Noise addition
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
540
+
541
+
542
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic).

    Second-order stochastic solver; uses a Brownian tree noise sampler by
    default so results are reproducible across step counts.
    """
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # DPM-Solver++ works in t = -log(sigma) space.
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method (t_fn is undefined at sigma 0)
            d = to_d(x, sigmas[i], denoised)
            dt = sigmas[i + 1] - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            s = t + h * r
            fac = 1 / (2 * r)

            # Step 1: ancestral sub-step to the intermediate time s.
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
            s_ = t_fn(sd)
            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)

            # Step 2: full step using a weighted combination of both denoised estimates.
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
            t_next_ = t_fn(sd)
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
    return x
582
+
583
+
584
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M): second-order multistep, one model call per step."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # DPM-Solver++ works in t = -log(sigma) space.
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None  # previous step's denoised estimate, for the multistep correction

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            # First step (no history) or final step: plain first-order update.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            # Second-order correction extrapolated from the previous denoised output.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x
608
+
609
+
610
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """DPM-Solver++(2M) SDE.

    Stochastic second-order multistep solver; `solver_type` selects the form
    of the second-order correction ('heun' or 'midpoint').
    """

    if solver_type not in {'heun', 'midpoint'}:
        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    old_denoised = None  # previous denoised estimate for the multistep correction
    h_last = None  # previous step size in t space

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # DPM-Solver++(2M) SDE, in t = -log(sigma) space.
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            eta_h = eta * h

            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised

            if old_denoised is not None:
                r = h_last / h
                if solver_type == 'heun':
                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
                elif solver_type == 'midpoint':
                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)

            if eta:
                # Re-inject noise matched to the eta-scaled contraction above.
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

        old_denoised = denoised
        h_last = h
    return x
653
+
654
+
655
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM-Solver++(3M) SDE: stochastic third-order multistep solver."""

    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    denoised_1, denoised_2 = None, None  # denoised estimates from the last two steps
    h_1, h_2 = None, None  # step sizes (in t space) of the last two steps

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            h_eta = h * (eta + 1)

            x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised

            if h_2 is not None:
                # Third-order correction from two stored history points.
                r0 = h_1 / h
                r1 = h_2 / h
                d1_0 = (denoised - denoised_1) / r0
                d1_1 = (denoised_1 - denoised_2) / r1
                d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
                d2 = (d1_0 - d1_1) / (r0 + r1)
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                phi_3 = phi_2 / h_eta - 0.5
                x = x + phi_2 * d1 - phi_3 * d2
            elif h_1 is not None:
                # Second-order correction when only one history point exists.
                r = h_1 / h
                d = (denoised - denoised_1) / r
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                x = x + phi_2 * d

            if eta:
                # Re-inject noise matched to the eta-scaled contraction above.
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

        denoised_1, denoised_2 = denoised, denoised_1
        h_1, h_2 = h, h_1
    return x
703
+
704
@torch.no_grad()
def sampler_dpmu(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Experimental "DPM-U" sampler: a DPM-Solver++(2M)-style update with a
    halved denoised contribution, per-step clamping, and a final blend of the
    last denoised estimate with the penultimate latent.

    NOTE(review): `dpmu_factor` is not defined in this function; it is assumed
    to be a module-level constant defined elsewhere in this file — confirm,
    otherwise the final step raises NameError.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    last_x = None
    for i in trange(len(sigmas) - 1, disable=disable):
        # The very first step uses the raw latent as the "denoised" estimate.
        denoised = x if i == 0 else model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if sigmas[i + 1] == 0:
            # Final step: blend the last denoised output with the stored latent.
            return torch.lerp(denoised, last_x, 0.5) * dpmu_factor
        else:
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * (1 + 1 / (2 * r)) * denoised / 2
            # Bug fix: guard the lookahead so schedules that do not end in 0
            # cannot raise IndexError on the last iteration.
            if i + 2 < len(sigmas) and sigmas[i + 2] == 0:
                last_x = x
        # Bug fix: torch.clamp is not in-place; the original discarded its result.
        x = torch.clamp(x, -1.0, 1.0)
    return x
sec/sd_schedulers.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import torch
3
+ import k_diffusion
4
+ import numpy as np
5
+ from scipy import stats
6
+ import modules.simple_karras_exponential_scheduler as simple_kes
7
+ from modules import shared
8
+
9
+
10
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    residual = x - denoised
    return residual / sigma
13
+
14
+
15
# Monkeypatch: make all of k_diffusion's samplers use the local to_d above.
k_diffusion.sampling.to_d = to_d
16
+
17
+
18
@dataclasses.dataclass
class Scheduler:
    """Registry entry describing one noise-schedule generator."""
    name: str  # internal identifier used in settings / API
    label: str  # human-readable name shown in the UI
    function: any  # schedule generator callable; NOTE(review): `any` is the builtin used as an annotation placeholder — consider typing.Any

    default_rho: float = -1  # default rho for schedules that accept one; -1 means not applicable
    need_inner_model: bool = False  # True when `function` takes the inner model as an argument
    aliases: list = None  # optional alternate names accepted for this scheduler
27
+
28
+
29
def uniform(n, sigma_min, sigma_max, inner_model, device):
    """Ask the inner model for its own n-step sigma schedule."""
    sigmas = inner_model.get_sigmas(n)
    return sigmas.to(device)
31
+
32
+
33
def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
    """SGM-style schedule: n + 1 timesteps uniform in t, the final timestep
    dropped, and a terminating 0.0 sigma appended."""
    start = inner_model.sigma_to_t(torch.tensor(sigma_max))
    end = inner_model.sigma_to_t(torch.tensor(sigma_min))
    timesteps = torch.linspace(start, end, n + 1)[:-1]
    sigs = [inner_model.t_to_sigma(ts) for ts in timesteps]
    sigs.append(0.0)
    return torch.FloatTensor(sigs).to(device)
42
+
43
+
44
def get_align_your_steps_sigmas(n, sigma_min, sigma_max, device):
    """Align Your Steps schedule (NVIDIA), resampled to `n` steps.

    Uses the published 10-step sigma tables for SDXL or SD 1.5 and
    log-linearly interpolates them to the requested step count.
    Note: sigma_min/sigma_max are accepted but unused (tables are fixed).
    """
    # https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
    def loglinear_interp(t_steps, num_steps):
        """
        Performs log-linear interpolation of a given array of decreasing numbers.
        """
        # Interpolate in log space over a normalized [0, 1] axis, then map back.
        xs = np.linspace(0, 1, len(t_steps))
        ys = np.log(t_steps[::-1])

        new_xs = np.linspace(0, 1, num_steps)
        new_ys = np.interp(new_xs, xs, ys)

        interped_ys = np.exp(new_ys)[::-1].copy()
        return interped_ys

    if shared.sd_model.is_sdxl:
        sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029]
    else:
        # Default to SD 1.5 sigmas.
        sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]

    if n != len(sigmas):
        # Resample the published table to the requested step count, then terminate at 0.
        sigmas = np.append(loglinear_interp(sigmas, n), [0.0])
    else:
        sigmas.append(0.0)

    return torch.FloatTensor(sigmas).to(device)
71
+
72
+
73
def kl_optimal(n, sigma_min, sigma_max, device):
    """Schedule that interpolates linearly between sigma_max and sigma_min in
    arctan(sigma) space, returning n + 1 sigmas from high to low."""
    alpha_min = torch.arctan(torch.tensor(sigma_min, device=device))
    alpha_max = torch.arctan(torch.tensor(sigma_max, device=device))
    frac = torch.arange(n + 1, device=device) / n
    return torch.tan(frac * alpha_min + (1.0 - frac) * alpha_max)
79
+
80
+
81
def simple_scheduler(n, sigma_min, sigma_max, inner_model, device):
    """Evenly stride the model's sigma table from the top down, then append a
    terminating 0.0 sigma."""
    table = inner_model.sigmas
    stride = len(table) / n
    sigs = [float(table[-(1 + int(step * stride))]) for step in range(n)]
    sigs.append(0.0)
    return torch.FloatTensor(sigs).to(device)
88
+
89
+
90
def normal_scheduler(n, sigma_min, sigma_max, inner_model, device, sgm=False, floor=False):
    """Schedule uniform in the model's t space, ending with a 0.0 sigma.

    With sgm=True, n + 1 timesteps are laid out and the last is dropped
    (SGM convention). `floor` is accepted for interface compatibility but unused.
    """
    start = inner_model.sigma_to_t(torch.tensor(sigma_max))
    end = inner_model.sigma_to_t(torch.tensor(sigma_min))

    count = n + 1 if sgm else n
    timesteps = torch.linspace(start, end, count)
    if sgm:
        timesteps = timesteps[:-1]

    sigs = [inner_model.t_to_sigma(ts) for ts in timesteps]
    sigs.append(0.0)
    return torch.FloatTensor(sigs).to(device)
105
+
106
+
107
def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):
    """DDIM-style schedule: stride the ascending sigma table (skipping index 0),
    reverse to descending order, and append a terminating 0.0 sigma."""
    table = inner_model.sigmas
    stride = max(len(table) // n, 1)
    picked = [float(table[idx]) for idx in range(1, len(table), stride)]
    sigs = picked[::-1]
    sigs.append(0.0)
    return torch.FloatTensor(sigs).to(device)
117
+
118
+
119
def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
    """Schedule whose sigmas follow a Beta distribution's quantiles between
    sigma_min and sigma_max; alpha/beta are read from shared.opts.
    `inner_model` is accepted for interface compatibility but unused.
    """
    # From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024) """
    alpha = shared.opts.beta_dist_alpha
    beta = shared.opts.beta_dist_beta
    # Evaluate the Beta quantile function on a descending uniform grid.
    timesteps = 1 - np.linspace(0, 1, n)
    timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
    sigmas = [sigma_min + (x * (sigma_max-sigma_min)) for x in timesteps]
    sigmas += [0.0]
    return torch.FloatTensor(sigmas).to(device)
128
+
129
+
130
# Registry of all available noise schedulers; order defines UI listing order.
schedulers = [
    Scheduler('automatic', 'Automatic', None),
    Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
    Scheduler('karras', 'Karras', k_diffusion.sampling.get_sigmas_karras, default_rho=7.0),
    Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),
    Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),
    Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
    Scheduler('kl_optimal', 'KL Optimal', kl_optimal),
    Scheduler('align_your_steps', 'Align Your Steps', get_align_your_steps_sigmas),
    Scheduler('simple', 'Simple', simple_scheduler, need_inner_model=True),
    Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True),
    Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True),
    Scheduler('beta', 'Beta', beta_scheduler, need_inner_model=True),
    Scheduler('karras_exponential', 'Karras Exponential', simple_kes.simple_karras_exponential_scheduler),
]

# Lookup by either internal name or UI label (labels take precedence on collision).
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
sec/simple_karras_exponential_scheduler.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #simple_karras_exponential_scheduler.py
2
+ import torch
3
+ import logging
4
+ from k_diffusion.sampling import get_sigmas_karras, get_sigmas_exponential
5
+ import os
6
+ import yaml
7
+ import random
8
+ from watchdog.observers import Observer
9
+ from watchdog.events import FileSystemEventHandler
10
+ from datetime import datetime
11
+ import warnings
12
+ import os
13
+ import logging
14
+ from datetime import datetime
15
def get_random_or_default(scheduler_config, key_prefix, default_value, global_randomize):
    """Helper function to either randomize a value based on conditions or return the default.

    Randomization happens when `global_randomize` is set or the per-key
    '<prefix>_rand' flag is true in `scheduler_config`; bounds come from
    '<prefix>_rand_min'/'<prefix>_rand_max', defaulting to +/-20% of the default.
    """
    should_randomize = global_randomize or scheduler_config.get(f'{key_prefix}_rand', False)

    if not should_randomize:
        # Use default value if no randomization is applied
        value = default_value
        custom_logger.info(f"Using default {key_prefix}: {value}")
        return value

    # Use specified min/max values for randomization if they exist, else use default range
    rand_min = scheduler_config.get(f'{key_prefix}_rand_min', default_value * 0.8)
    rand_max = scheduler_config.get(f'{key_prefix}_rand_max', default_value * 1.2)
    value = random.uniform(rand_min, rand_max)
    custom_logger.info(f"Randomized {key_prefix}: {value}")
    return value
33
+
34
+
35
class CustomLogger:
    """File-based logger pair for the simple KES scheduler.

    Writes generation info and errors to timestamped files in sibling
    'simple_kes_generation' / 'simple_kes_error' directories, optionally
    mirroring both to the console.

    NOTE(review): logging.getLogger returns the same named logger on every
    call, so instantiating this class more than once adds duplicate handlers
    (stdlib logging behavior) — confirm single instantiation per process.
    """
    def __init__(self, log_name, print_to_console=False, debug_enabled=False):
        self.print_to_console = print_to_console #prints to console
        self.debug_enabled = debug_enabled #logs debug messages

        # Create folders for generation info and error logs
        gen_log_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'simple_kes_generation')
        error_log_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'simple_kes_error')

        os.makedirs(gen_log_dir, exist_ok=True)
        os.makedirs(error_log_dir, exist_ok=True)

        # Get current time in HH-MM-SS format
        current_time = datetime.now().strftime('%H-%M-%S')

        # Create file paths for the log files
        gen_log_file_path = os.path.join(gen_log_dir, f'{current_time}.log')
        error_log_file_path = os.path.join(error_log_dir, f'{current_time}.log')

        # Set up generation logger
        # NOTE(review): the generation logger name is hard-coded, so `log_name`
        # only affects the error logger — confirm this is intentional.
        #self.gen_logger = logging.getLogger(f'{log_name}_generation')
        self.gen_logger = logging.getLogger('simple_kes_generation')
        self.gen_logger.setLevel(logging.DEBUG)
        self._setup_file_handler(self.gen_logger, gen_log_file_path)

        # Set up error logger
        self.error_logger = logging.getLogger(f'{log_name}_error')
        self.error_logger.setLevel(logging.ERROR)
        self._setup_file_handler(self.error_logger, error_log_file_path)

        # Prevent log propagation to root logger (important to avoid accidental console logging)
        self.gen_logger.propagate = False
        self.error_logger.propagate = False


        # Optionally print to console
        if self.print_to_console:
            self._setup_console_handler(self.gen_logger)
            self._setup_console_handler(self.error_logger)

    def _setup_file_handler(self, logger, file_path):
        """Set up file handler for logging to a file."""
        file_handler = logging.FileHandler(file_path, mode='a')
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    def _setup_console_handler(self, logger):
        """Optionally set up a console handler for logging to the console."""
        console_handler = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    def log_debug(self, message):
        """Log a debug message (no-op unless debug_enabled)."""
        if self.debug_enabled:
            self.gen_logger.debug(message)

    def log_info(self, message):
        """Log an info message."""
        self.gen_logger.info(message)
    info=log_info #alias created so callers can use logger.info(...)

    def log_error(self, message):
        """Log an error message."""
        self.error_logger.error(message)

    def enable_console_logging(self):
        """Enable console logging dynamically (idempotent: skips if a stream handler exists)."""
        if not any(isinstance(handler, logging.StreamHandler) for handler in self.gen_logger.handlers):
            self._setup_console_handler(self.gen_logger)

        if not any(isinstance(handler, logging.StreamHandler) for handler in self.error_logger.handlers):
            self._setup_console_handler(self.error_logger)
110
+
111
# Module-level logger instance used throughout this file.
custom_logger = CustomLogger('simple_kes', print_to_console=False, debug_enabled=True)

# Logging examples:
#   custom_logger.log_debug("Debug message: Using default sigma_min: 0.01")
#   custom_logger.info("Info message: Step completed successfully.")
#   custom_logger.log_error("Error message: Something went wrong!")
118
+
119
+
120
class ConfigManagerYaml:
    """Loads and caches a YAML configuration file.

    The parsed content is kept in `config_data`; `load_config()` may be called
    again at any time (e.g. by ConfigWatcher) to refresh it.
    """

    def __init__(self, config_path):
        self.config_path = config_path
        self.config_data = self.load_config()  # eager initial load

    def load_config(self):
        """Parse the YAML file and return a dict.

        Returns an empty dict when the file is missing, invalid, or empty, so
        callers can always safely chain `.get(...)` on the result.
        """
        try:
            with open(self.config_path, 'r') as f:
                user_config = yaml.safe_load(f)
            # BUGFIX: yaml.safe_load returns None for an empty file; normalize
            # to {} so downstream .get('scheduler', {}) calls don't crash.
            return user_config or {}
        except FileNotFoundError:
            print(f"Config file not found: {self.config_path}. Using empty config.")
            return {}
        except yaml.YAMLError as e:
            print(f"Error loading config file: {e}")
            return {}
136
+
137
+
138
# ConfigWatcher monitors changes to the config file and reloads it during
# program use (so you can continue working without restarting the program).
class ConfigWatcher(FileSystemEventHandler):
    """Watchdog handler that reloads the managed config when its file changes."""

    def __init__(self, config_manager, config_path):
        self.config_manager = config_manager
        self.config_path = config_path

    def on_modified(self, event):
        # React only to the watched config file, not to other files in the
        # same directory (the observer watches the whole directory).
        if event.src_path == self.config_path:
            logging.info(f"Config file {self.config_path} modified. Reloading config.")
            self.config_manager.config_data = self.config_manager.load_config()
148
+
149
+
150
+
151
def start_config_watcher(config_manager, config_path):
    """Begin watching the directory containing `config_path` for modifications.

    Returns the started watchdog Observer so the caller can stop/join it later.
    """
    watcher = ConfigWatcher(config_manager, config_path)
    obs = Observer()
    obs.schedule(watcher, os.path.dirname(config_path), recursive=False)
    obs.start()
    return obs
157
+
158
+
159
+ """
160
+ Scheduler function that blends sigma sequences using Karras and Exponential methods with adaptive parameters.
161
+
162
+ Parameters are dynamically updated if the config file changes during execution.
163
+ """
164
+ # If user config is provided, update default config with user values
165
+ config_path = "modules/simple_kes_scheduler.yaml"
166
+ config_manager = ConfigManagerYaml(config_path)
167
+
168
+
169
+ # Start watching for config changes
170
+ observer = start_config_watcher(config_manager, config_path)
171
+
172
+
173
def simple_karras_exponential_scheduler(
    n, device, sigma_min=0.01, sigma_max=50, start_blend=0.1, end_blend=0.5,
    sharpness=0.95, early_stopping_threshold=0.01, update_interval=10, initial_step_size=0.9,
    final_step_size=0.2, initial_noise_scale=1.25, final_noise_scale=0.8, smooth_blend_factor=11, step_size_factor=0.8, noise_scale_factor=0.9, randomize=False, user_config=None
):
    """
    Scheduler function that blends sigma sequences using Karras and Exponential
    methods with adaptive parameters. Values are re-read from the YAML config on
    every call, so edits made while the program runs take effect immediately.

    Parameters:
        n (int): Number of steps.
        device (torch.device): Device on which to compute (e.g. 'cuda' or 'cpu').
        sigma_min (float): Minimum sigma value.
        sigma_max (float): Maximum sigma value.
        start_blend (float): Initial blend factor for dynamic blending.
        end_blend (float): Final blend factor for dynamic blending.
        sharpness (float): Sharpening factor applied adaptively to small sigmas.
        early_stopping_threshold (float): Threshold to trigger early stopping.
        update_interval (int): Interval to update blend factors.
        initial_step_size (float): Initial step size for adaptive step size calculation.
        final_step_size (float): Final step size for adaptive step size calculation.
        initial_noise_scale (float): Initial noise scale factor.
        final_noise_scale (float): Final noise scale factor.
        smooth_blend_factor (float): Sigmoid scaling of the blend transition.
        step_size_factor (float): Adjust to compensate for oversmoothing.
        noise_scale_factor (float): Adjust to provide more variation.
        randomize (bool): Randomize parameters within configured ranges.
        user_config (dict | None): Unused; kept for backward compatibility —
            configuration is read from the module-level config_manager.

    Returns:
        torch.Tensor: A tensor of blended sigma values on `device`.

    Raises:
        ValueError: If sigma generation fails or produces NaN/Inf values.
    """
    # Load the live 'scheduler' section of the YAML config.
    config = config_manager.load_config().get('scheduler', {})
    if not config:
        warnings.warn("Scheduler configuration is missing from the config file. Using default values.")
    custom_logger.info(f"Configuration loaded from YAML: {config}")

    default_config = {
        "debug": False,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "sigma_min": 0.01,
        "sigma_max": 50,  # if sigma_max is too low the resulting picture may be undesirable
        "start_blend": 0.1,
        "end_blend": 0.5,
        "sharpness": 0.95,
        "early_stopping_threshold": 0.01,
        "update_interval": 10,
        "initial_step_size": 0.9,
        "final_step_size": 0.2,
        "initial_noise_scale": 1.25,
        "final_noise_scale": 0.8,
        "smooth_blend_factor": 11,
        "step_size_factor": 0.8,  # suggested value to avoid oversmoothing
        "noise_scale_factor": 0.9,  # suggested value to add more variation
        "randomize": False,
        "sigma_min_rand": False,
        "sigma_min_rand_min": 0.001,
        "sigma_min_rand_max": 0.05,
        "sigma_max_rand": False,
        "sigma_max_rand_min": 0.05,
        "sigma_max_rand_max": 0.20,
        "start_blend_rand": False,
        "start_blend_rand_min": 0.05,
        "start_blend_rand_max": 0.2,
        "end_blend_rand": False,
        "end_blend_rand_min": 0.4,
        "end_blend_rand_max": 0.6,
        "sharpness_rand": False,
        "sharpness_rand_min": 0.85,
        "sharpness_rand_max": 1.0,
        "early_stopping_rand": False,
        "early_stopping_rand_min": 0.001,
        "early_stopping_rand_max": 0.02,
        "update_interval_rand": False,
        "update_interval_rand_min": 5,
        "update_interval_rand_max": 10,
        "initial_step_rand": False,
        "initial_step_rand_min": 0.7,
        "initial_step_rand_max": 1.0,
        "final_step_rand": False,
        "final_step_rand_min": 0.1,
        "final_step_rand_max": 0.3,
        "initial_noise_rand": False,
        "initial_noise_rand_min": 1.0,
        "initial_noise_rand_max": 1.5,
        "final_noise_rand": False,
        "final_noise_rand_min": 0.6,
        "final_noise_rand_max": 1.0,
        "smooth_blend_factor_rand": False,
        "smooth_blend_factor_rand_min": 6,
        "smooth_blend_factor_rand_max": 11,
        "step_size_factor_rand": False,
        "step_size_factor_rand_min": 0.65,
        "step_size_factor_rand_max": 0.85,
        "noise_scale_factor_rand": False,
        "noise_scale_factor_rand_min": 0.75,
        "noise_scale_factor_rand_max": 0.95,
    }
    custom_logger.info(f"Default Config create {default_config}")

    # Merge: YAML values override defaults; unknown keys are logged and ignored.
    for key, value in config.items():
        if key in default_config:
            default_config[key] = value
            custom_logger.info(f"Overriding default config: {key} = {value}")
        else:
            custom_logger.info(f"Ignoring unknown config option: {key}")
    custom_logger.info(f"Final configuration after merging with YAML: {default_config}")

    # If the global flag is true, every parameter is randomized within its
    # configured range; otherwise only parameters whose individual *_rand flag
    # is true are randomized (handled inside get_random_or_default).
    global_randomize = default_config.get('randomize', False)
    custom_logger.info(f"Global randomization flag set to: {global_randomize}")

    sigma_min = get_random_or_default(config, 'sigma_min', sigma_min, global_randomize)
    sigma_max = get_random_or_default(config, 'sigma_max', sigma_max, global_randomize)
    start_blend = get_random_or_default(config, 'start_blend', start_blend, global_randomize)
    end_blend = get_random_or_default(config, 'end_blend', end_blend, global_randomize)
    sharpness = get_random_or_default(config, 'sharpness', sharpness, global_randomize)
    early_stopping_threshold = get_random_or_default(config, 'early_stopping', early_stopping_threshold, global_randomize)
    update_interval = get_random_or_default(config, 'update_interval', update_interval, global_randomize)
    initial_step_size = get_random_or_default(config, 'initial_step', initial_step_size, global_randomize)
    final_step_size = get_random_or_default(config, 'final_step', final_step_size, global_randomize)
    initial_noise_scale = get_random_or_default(config, 'initial_noise', initial_noise_scale, global_randomize)
    final_noise_scale = get_random_or_default(config, 'final_noise', final_noise_scale, global_randomize)
    smooth_blend_factor = get_random_or_default(config, 'smooth_blend_factor', smooth_blend_factor, global_randomize)
    step_size_factor = get_random_or_default(config, 'step_size_factor', step_size_factor, global_randomize)
    noise_scale_factor = get_random_or_default(config, 'noise_scale_factor', noise_scale_factor, global_randomize)

    # Expand sigma_max slightly to account for smoother transitions.
    sigma_max = sigma_max * 1.1
    custom_logger.info(f"Using device: {device}")

    # Generate sigma sequences using Karras and Exponential methods.
    sigmas_karras = get_sigmas_karras(n=n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
    sigmas_exponential = get_sigmas_exponential(n=n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
    if sigmas_karras is None:
        # BUGFIX: the original message lacked the f-prefix, printing the
        # placeholder text literally.
        raise ValueError(f"Sigmas Karras: {sigmas_karras} Failed to generate or assign sigmas correctly.")
    if sigmas_exponential is None:
        raise ValueError(f"Sigmas Exponential: {sigmas_exponential} Failed to generate or assign sigmas correctly.")

    # Match lengths of the two sigma sequences.
    target_length = min(len(sigmas_karras), len(sigmas_exponential))
    sigmas_karras = sigmas_karras[:target_length]
    sigmas_exponential = sigmas_exponential[:target_length]
    custom_logger.info(f"Generated sigma sequences. Karras: {sigmas_karras}, Exponential: {sigmas_exponential}")

    # NOTE(review): the original had a no-op `try: pass` here whose `finally`
    # clause stopped the module-level config watcher, which disabled live
    # config reloading after the first scheduler call. The watcher is now
    # deliberately left running; it is owned by module setup, not this function.

    # Normalized progress [0, 1] across steps, and the output buffer.
    progress = torch.linspace(0, 1, len(sigmas_karras)).to(device)
    sigs = torch.zeros_like(sigmas_karras).to(device)

    # Per step: adapt step size, blend factor and noise scale, then blend the
    # two schedules through a sigmoid for a smooth transition.
    # (Per-iteration tensor logging from the original was removed: it logged
    # full tensors on every step, overwhelming the log for no diagnostic gain.)
    for i in range(len(sigmas_karras)):
        step_size = initial_step_size * (1 - progress[i]) + final_step_size * progress[i] * step_size_factor  # factor compensates for oversmoothing
        dynamic_blend_factor = start_blend * (1 - progress[i]) + end_blend * progress[i]
        noise_scale = initial_noise_scale * (1 - progress[i]) + final_noise_scale * progress[i] * noise_scale_factor  # factor keeps more variation

        # Larger smooth_blend_factor sharpens the sigmoid transition.
        smooth_blend = torch.sigmoid((dynamic_blend_factor - 0.5) * smooth_blend_factor)

        blended_sigma = sigmas_karras[i] * (1 - smooth_blend) + sigmas_exponential[i] * smooth_blend
        sigs[i] = blended_sigma * step_size * noise_scale

    # Optional: adaptive sharpening of sigmas close to sigma_min.
    # Explicit tensors avoid reliance on torch.where's scalar overloads.
    sharpen_mask = torch.where(sigs < sigma_min * 1.5, torch.full_like(sigs, sharpness), torch.ones_like(sigs))
    custom_logger.info(f"sharpen_mask created {sharpen_mask} with device {device}")
    sigs = sigs * sharpen_mask

    # Early-stop when consecutive sigmas have converged below the threshold.
    change = torch.abs(sigs[1:] - sigs[:-1])
    if torch.all(change < early_stopping_threshold):
        custom_logger.info("Early stopping criteria met.")
        return sigs[:len(change) + 1].to(device)

    if torch.isnan(sigs).any() or torch.isinf(sigs).any():
        raise ValueError("Invalid sigma values detected (NaN or Inf).")

    return sigs.to(device)
sec/simple_kes_scheduler.yaml ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scheduler:
2
+
3
+ #Optionally print to a log file for debugging. If false, debug is turned off, and no log file will be created.
4
+ #config options: true or false
5
+ debug: false
6
+
7
+ # The minimum value for the noise level (sigma) during image generation.
8
+ # Decreasing this value makes the image clearer but less detailed.
9
+ # Increasing it makes the image noisier but potentially more artistic or abstract.
10
+ sigma_min: 0.01 # Default: 0.01, Suggested range: 0.01 - 0.1
11
+
12
+ # The maximum value for the noise level (sigma) during image generation.
13
+ # Increasing this value can create more variation in the image details.
14
+ # Lower values keep the image more stable and less noisy.
15
+ sigma_max: 50 # Default: 50, Suggested range: 10 - 60
16
+
17
+ # The device used for running the scheduler. If you have a GPU, set this to "cuda".
18
+ # Otherwise, use "cpu", but note that it will be significantly slower.
19
+ #device: "cuda" # Options: "cuda" (GPU) or "cpu" (processor)
20
+
21
+ # Initial blend factor between Karras and Exponential noise methods.
22
+ # A higher initial blend makes the image sharper at the start.
23
+ # A lower initial blend makes the image smoother early on.
24
+ start_blend: 0.1 # Default: 0.1, Suggested range: 0.05 - 0.2
25
+
26
+ # Final blend factor between Karras and Exponential noise methods.
27
+ # Higher values blend more noise at the end, possibly adding more detail.
28
+ # Lower values blend less noise for smoother, simpler images at the end.
29
+ end_blend: 0.5 # Default: 0.5, Suggested range: 0.4 - 0.6
30
+
31
+ # Sharpening factor applied to images during generation.
32
+ # Higher values increase sharpness but can add unwanted artifacts.
33
+ # Lower values reduce sharpness but may make the image look blurry.
34
+ sharpness: 0.95 # Default: 0.95, Suggested range: 0.8 - 1.0
35
+
36
+ # Early stopping threshold for stopping the image generation when changes between steps are minimal.
37
+ # Lower values stop early, saving time, but might produce incomplete images.
38
+ # Higher values take longer but may give more detailed results.
39
+ early_stopping_threshold: 0.01 # Default: 0.01, Suggested range: 0.005 - 0.02
40
+
41
+ # The number of steps between updates of the blend factor.
42
+ # Smaller values update the blend more frequently for smoother transitions.
43
+ # Larger values update the blend less frequently for faster processing.
44
+ update_interval: 10 # Default: 10, Suggested range: 5 - 15
45
+
46
+ # Initial step size, which controls how quickly the image evolves early on.
47
+ # Higher values make big changes at the start, possibly generating faster but less refined images.
48
+ # Lower values make smaller changes, giving more control over details.
49
+ initial_step_size: 0.9 # Default, 0.9, Suggested range: 0.5 - 1.0
50
+
51
+ # Final step size, which controls how much the image changes towards the end.
52
+ # Higher values keep details more flexible until the end, which may add complexity.
53
+ # Lower values lock the details earlier, making the image simpler.
54
+ final_step_size: 0.2 # Default: 0.2, Suggested range: 0.1 - 0.3
55
+
56
+ # Initial noise scaling applied to the image generation process.
57
+ # Higher values add more noise early on, making the initial image more random.
58
+ # Lower values reduce noise early on, leading to a smoother initial image.
59
+ initial_noise_scale: 1.25 # Default, 1.25, Suggested range: 1.0 - 1.5
60
+
61
+ # Final noise scaling applied at the end of the image generation.
62
+ # Higher values add noise towards the end, possibly adding fine detail.
63
+ # Lower values reduce noise towards the end, making the final image smoother.
64
+ final_noise_scale: 0.8 # Default, 0.8, Suggested range: 0.6 - 1.0
65
+
66
+
67
+ smooth_blend_factor: 11 #Default: 11, try 6 for more variation
68
+ step_size_factor: 0.75 #suggested value (0.8) to avoid oversmoothing
69
+ noise_scale_factor: 0.95 #suggested value (0.9) to add more variation
70
+
71
+
72
+ # Enables global randomization.
73
+ # If true, all parameters are randomized within specified min/max ranges.
74
+ # If false, individual parameters with _rand flags set to true will still be randomized.
75
+ randomize: true
76
+
77
+ #Sigma values typically start very small. Lowering this could allow more gradual noise reduction. Too large would overwhelm the process.
78
+ sigma_min_rand: false
79
+ sigma_min_rand_min: 0.001
80
+ sigma_min_rand_max: 0.05
81
+
82
+ #Sigma max controls the upper limit of the noise. A lower minimum could allow faster convergence, while a higher max gives more flexibility for noisier images.
83
+ sigma_max_rand: false
84
+ sigma_max_rand_min: 10
85
+ sigma_max_rand_max: 60
86
+
87
+ #Start blend controls how strongly Karras and Exponential are blended at the start. A slightly lower value introduces more variety in the blending at the beginning.
88
+ start_blend_rand: false
89
+ start_blend_rand_min: 0.05
90
+ start_blend_rand_max: 0.2
91
+
92
+ # End blend affects how much the blending changes towards the end. Increasing the upper limit would allow more variation.
93
+ end_blend_rand: false
94
+ end_blend_rand_min: 0.4
95
+ end_blend_rand_max: 0.6
96
+
97
+ # Sharpness controls detail retention. You wouldn’t want to lower it too much, as it might lose detail.
98
+ sharpness_rand: false
99
+ sharpness_rand_min: 0.85
100
+ sharpness_rand_max: 1.0
101
+
102
+ #A smaller early stopping threshold could lead to earlier stopping if the changes between sigma steps become too small, while the upper value would prevent early stopping until larger changes occur.
103
+ early_stopping_rand: false
104
+ early_stopping_rand_min: 0.001
105
+ early_stopping_rand_max: 0.02
106
+
107
+ #Update intervals affect how frequently blending factors are updated. More frequent updates allow more flexibility in blending.
108
+ update_interval_rand: false
109
+ update_interval_rand_min: 5
110
+ update_interval_rand_max: 10
111
+
112
+ # The initial step size defines how large the steps are at the start. A slightly smaller value introduces more gradual transitions.
113
+ initial_step_rand: false
114
+ initial_step_rand_min: 0.7
115
+ initial_step_rand_max: 1.0
116
+
117
+ # The final step size defines how small the steps become towards the end. A slightly larger range gives more control over the final convergence.
118
+ final_step_rand: false
119
+ final_step_rand_min: 0.1
120
+ final_step_rand_max: 0.3
121
+
122
+ #Initial noise scale defines how much noise to introduce initially. Larger values make the process start with more randomness, while smaller values keep it controlled.
123
+ initial_noise_rand: false
124
+ initial_noise_rand_min: 1.0
125
+ initial_noise_rand_max: 1.5
126
+
127
+ # Final noise scale affects how much noise is reduced at the end. A lower minimum allows more noise to persist, while a higher maximum ensures full convergence.
128
+ final_noise_rand: false
129
+ final_noise_rand_min: 0.6
130
+ final_noise_rand_max: 1.0
131
+
132
+ #The smooth blend factor controls how aggressively the blending is smoothed. Lower values allow more abrupt blending changes, while higher values give smoother transitions.
133
+ smooth_blend_factor_rand: false
134
+ smooth_blend_factor_rand_min: 6
135
+ smooth_blend_factor_rand_max: 11
136
+
137
+ #Step size factor adjusts the step size dynamically to avoid oversmoothing. A lower minimum increases variety, while a higher max provides smoother results.
138
+ step_size_factor_rand: false
139
+ step_size_factor_rand_min: 0.65
140
+ step_size_factor_rand_max: 0.85
141
+
142
+ # Noise scale factor controls how noise is scaled throughout the steps. A slightly lower minimum adds more variety, while keeping the maximum value near the suggested ensures more uniform results.
143
+ noise_scale_factor_rand: false
144
+ noise_scale_factor_rand_min: 0.75
145
+ noise_scale_factor_rand_max: 0.95
146
+
sec/txt2img.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from contextlib import closing
3
+
4
+ import modules.scripts
5
+ from modules import processing, infotext_utils
6
+ from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
7
+ from modules.shared import opts
8
+ import modules.shared as shared
9
+ from modules.ui import plaintext_to_html
10
+ from PIL import Image
11
+ import gradio as gr
12
+
13
+
14
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles,
                              n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool,
                              denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int,
                              hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str,
                              hr_prompt: str, hr_negative_prompt, override_settings_texts, enable_progressive_growing: bool,
                              progressive_growing_min_scale: float, progressive_growing_max_scale: float, progressive_growing_steps: int,
                              progressive_growing_refinement: bool, *args, force_enable_hr=False):
    """Build a StableDiffusionProcessingTxt2Img object from the UI arguments.

    force_enable_hr overrides the UI's enable_hr checkbox (used by
    txt2img_upscale). Extra positional *args are forwarded to scripts via
    p.script_args. Returns the configured processing object; generation itself
    is performed by the caller.
    """
    override_settings = create_override_settings_dict(override_settings_texts)

    if force_enable_hr:
        enable_hr = True

    # BUGFIX(cleanup): leftover debug print() calls for the progressive-growing
    # parameters were removed; they spammed stdout on every generation.

    p = processing.StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
        prompt=prompt,
        styles=prompt_styles,
        negative_prompt=negative_prompt,
        batch_size=batch_size,
        n_iter=n_iter,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        enable_hr=enable_hr,
        denoising_strength=denoising_strength,
        hr_scale=hr_scale,
        hr_upscaler=hr_upscaler,
        hr_second_pass_steps=hr_second_pass_steps,
        hr_resize_x=hr_resize_x,
        hr_resize_y=hr_resize_y,
        hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
        hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
        hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler,
        hr_prompt=hr_prompt,
        hr_negative_prompt=hr_negative_prompt,
        override_settings=override_settings,
    )

    p.id_task = id_task
    # Progressive-growing settings are attached as plain attributes so scripts
    # can read them from the processing object.
    p.enable_progressive_growing = enable_progressive_growing
    p.progressive_growing_min_scale = progressive_growing_min_scale
    p.progressive_growing_max_scale = progressive_growing_max_scale
    p.progressive_growing_steps = progressive_growing_steps
    p.progressive_growing_refinement = progressive_growing_refinement
    p.scripts = modules.scripts.scripts_txt2img
    p.script_args = args

    p.user = request.username

    if shared.opts.enable_console_prompts:
        print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)

    return p
72
+
73
+
74
def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
    """Run the hires-fix second pass on one image selected from the gallery.

    Rebuilds a processing object with hires forced on, seeds it from the
    selected image's infotext, and returns an updated gallery (placeholders for
    the untouched images), updated generation info JSON, and HTML info/comments.
    """
    assert len(gallery) > 0, 'No image to upscale'
    assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'

    p = txt2img_create_processing(id_task, request, *args, force_enable_hr=True)
    # Upscaling operates on exactly one image.
    p.batch_size = 1
    p.n_iter = 1
    # txt2img_upscale attribute that signifies this is called by txt2img_upscale
    p.txt2img_upscale = True

    geninfo = json.loads(generation_info)

    selected = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0]
    p.firstpass_image = infotext_utils.image_from_url_text(selected)

    # Recover seed/variation seed from the selected image's infotext so the
    # second pass reproduces the same image.
    parsed = parse_generation_parameters(geninfo.get('infotexts')[gallery_index], [])
    p.seed = parsed.get('Seed', -1)
    p.subseed = parsed.get('Variation seed', -1)

    p.override_settings['save_images_before_highres_fix'] = False

    with closing(p):
        processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
        if processed is None:
            processed = processing.process_images(p)

    shared.total_tqdm.clear()

    # Rebuild the gallery: real results replace the selected slot; every other
    # slot becomes a 1x1 placeholder pointing at the already-saved file so
    # gradio doesn't re-upload untouched images.
    new_gallery = []
    for idx, item in enumerate(gallery):
        if idx != gallery_index:
            fake_image = Image.new(mode="RGB", size=(1, 1))
            fake_image.already_saved_as = item["name"].rsplit('?', 1)[0]
            new_gallery.append(fake_image)
        else:
            geninfo["infotexts"][gallery_index: gallery_index+1] = processed.infotexts
            new_gallery.extend(processed.images)

    geninfo["infotexts"][gallery_index] = processed.info

    return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
116
+
117
+
118
def txt2img(id_task: str, request: gr.Request, *args):
    """Top-level txt2img entry point for the UI.

    Builds the processing object, lets scripts run first (falling back to the
    default pipeline when no script handles it), and returns the images plus
    generation info as JS/HTML for gradio.
    """
    p = txt2img_create_processing(id_task, request, *args)

    with closing(p):
        processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
        if processed is None:
            processed = processing.process_images(p)

    shared.total_tqdm.clear()

    generation_info_js = processed.js()
    if opts.samples_log_stdout:
        print(generation_info_js)

    # Optionally suppress image output in the UI (generation still happened).
    if opts.do_not_show_images:
        processed.images = []

    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
sec/ui.py ADDED
@@ -0,0 +1,1249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import mimetypes
3
+ import os
4
+ import sys
5
+ from functools import reduce
6
+ import warnings
7
+ from contextlib import ExitStack
8
+
9
+ import gradio as gr
10
+ import gradio.utils
11
+ import numpy as np
12
+ from PIL import Image, PngImagePlugin # noqa: F401
13
+ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call, wrap_gradio_call_no_job # noqa: F401
14
+
15
+ from modules import gradio_extensons, sd_schedulers # noqa: F401
16
+ from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
17
+ from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
18
+ from modules.paths import script_path
19
+ from modules.ui_common import create_refresh_button
20
+ from modules.ui_gradio_extensions import reload_javascript
21
+
22
+ from modules.shared import opts, cmd_opts
23
+
24
+ import modules.infotext_utils as parameters_copypaste
25
+ import modules.hypernetworks.ui as hypernetworks_ui
26
+ import modules.textual_inversion.ui as textual_inversion_ui
27
+ import modules.textual_inversion.textual_inversion as textual_inversion
28
+ import modules.shared as shared
29
+ from modules import prompt_parser
30
+ from modules.sd_hijack import model_hijack
31
+ from modules.infotext_utils import image_from_url_text, PasteField
32
+
33
# Re-export so modules that historically imported create_setting_component from here keep working.
create_setting_component = ui_settings.create_setting_component

# Show or suppress warning categories according to the user's settings.
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)

# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
mimetypes.add_type('application/javascript', '.mjs')

# Likewise, add explicit content-type header for certain missing image types
mimetypes.add_type('image/webp', '.webp')
mimetypes.add_type('image/avif', '.avif')

if not cmd_opts.share and not cmd_opts.listen:
    # fix gradio phoning home
    gradio.utils.version_check = lambda: None
    gradio.utils.get_local_ip_address = lambda: '127.0.0.1'

if cmd_opts.ngrok is not None:
    # ngrok tunnel requested on the command line; connect before the UI starts serving
    import modules.ngrok as ngrok
    print('ngrok authtoken detected, trying to connect...')
    ngrok.connect(
        cmd_opts.ngrok,
        cmd_opts.port if cmd_opts.port is not None else 7860,  # 7860 is gradio's default port
        cmd_opts.ngrok_options
    )
60
+
61
+
62
def gr_show(visible=True):
    """Build a gradio update payload that toggles a component's visibility."""
    return {"__type__": "update", "visible": visible}
64
+
65
+
66
# Sample image offered in the img2img tab; set to None when the asset is not present on disk.
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None

# Using constants for these since the variation selector isn't visible.
# Important that they exactly match script.js for tooltip to work.
random_symbol = '\U0001f3b2\ufe0f'  # 🎲️
reuse_symbol = '\u267b\ufe0f'  # ♻️
paste_symbol = '\u2199\ufe0f'  # ↙
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
apply_style_symbol = '\U0001f4cb'  # 📋
clear_prompt_symbol = '\U0001f5d1\ufe0f'  # 🗑️
extra_networks_symbol = '\U0001F3B4'  # 🎴
switch_values_symbol = '\U000021C5'  # ⇅
restore_progress_symbol = '\U0001F300'  # 🌀
detect_image_size_symbol = '\U0001F4D0'  # 📐


# Re-export so modules that historically imported plaintext_to_html from here keep working.
plaintext_to_html = ui_common.plaintext_to_html
85
+
86
+
87
def send_gradio_gallery_to_image(x):
    """Return an image built from the first gallery item, or None for an empty gallery."""
    if x:
        return image_from_url_text(x[0])
    return None
91
+
92
+
93
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
    """Return an HTML fragment showing the hires-fix source and target resolutions.

    Returns an empty string when hires fix is disabled.
    """
    if not enable:
        return ""

    params = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
    params.calculate_target_resolution()

    # explicit resize values win; otherwise fall back to the scale-derived upscale target
    target_w = params.hr_resize_x or params.hr_upscale_to_x
    target_h = params.hr_resize_y or params.hr_upscale_to_y

    return f"from <span class='resolution'>{params.width}x{params.height}</span> to <span class='resolution'>{target_w}x{target_h}</span>"
101
+
102
+
103
def resize_from_to_html(width, height, scale_by):
    """Return an HTML fragment describing a resize from (width, height) by scale_by.

    Produces "no image selected" when either scaled dimension truncates to zero.
    """
    target_width = int(width * scale_by)
    target_height = int(height * scale_by)

    if target_width and target_height:
        return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"

    return "no image selected"
111
+
112
+
113
def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
    """Run an interrogation function against the image selected by the img2img mode.

    Modes 0, 1, 3 and 4 interrogate the single image found at ii_singles[mode];
    mode 2 interrogates the "image" entry of the inpaint editor payload at
    ii_singles[2]; mode 5 batch-processes every file in ii_input_dir, writing
    one "<name>.txt" per image into ii_output_dir (or next to the inputs when
    ii_output_dir is empty).

    Returns [prompt, None] for single-image modes, [gr.update(), None] after a
    batch run, and None for an unrecognized mode.
    """
    if mode in {0, 1, 3, 4}:
        return [interrogation_function(ii_singles[mode]), None]
    elif mode == 2:
        return [interrogation_function(ii_singles[mode]["image"]), None]
    elif mode == 5:
        assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
        images = shared.listfiles(ii_input_dir)
        print(f"Will process {len(images)} images.")
        if ii_output_dir != "":
            os.makedirs(ii_output_dir, exist_ok=True)
        else:
            ii_output_dir = ii_input_dir

        for image in images:
            img = Image.open(image)
            filename = os.path.basename(image)
            left, _ = os.path.splitext(filename)
            # fix: open the output file in a context manager so each handle is
            # closed immediately instead of leaking until garbage collection
            with open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8') as output_file:
                print(interrogation_function(img), file=output_file)

        return [gr.update(), None]
134
+
135
+
136
def interrogate(image):
    """Interrogate *image* (converted to RGB) with shared.interrogator and
    return the resulting prompt, or a gradio no-op update when it is None."""
    prompt = shared.interrogator.interrogate(image.convert("RGB"))
    if prompt is None:
        return gr.update()
    return prompt
139
+
140
+
141
def interrogate_deepbooru(image):
    """Tag *image* with the DeepBooru model and return the resulting prompt,
    or a gradio no-op update when tagging yields None."""
    prompt = deepbooru.model.tag(image)
    if prompt is None:
        return gr.update()
    return prompt
144
+
145
+
146
def connect_clear_prompt(button):
    """Wire the clear-prompt button's click event.

    No python callback is attached (fn=None); the actual clearing happens in
    the `clear_prompt` javascript function.
    """
    button.click(
        fn=None,
        _js="clear_prompt",
        inputs=[],
        outputs=[],
    )
154
+
155
+
156
def update_token_counter(text, steps, styles, *, is_positive=True):
    """Compute the token count for a prompt and return it as an HTML span.

    The prompt is first passed through the before-token-counter callbacks
    (which may rewrite any of the inputs), optionally has the selected styles
    applied, then is parsed the same way generation would parse it so the
    count reflects scheduling and multi-cond syntax. Returns a string like
    "<span ...>count/max</span>" using the largest count across schedules.
    """
    params = script_callbacks.BeforeTokenCounterParams(text, steps, styles, is_positive=is_positive)
    script_callbacks.before_token_counter_callback(params)
    text = params.prompt
    steps = params.steps
    styles = params.styles
    is_positive = params.is_positive

    if shared.opts.include_styles_into_token_counters:
        if is_positive:
            text = shared.prompt_styles.apply_styles_to_prompt(text, styles)
        else:
            text = shared.prompt_styles.apply_negative_styles_to_prompt(text, styles)

    try:
        text, _ = extra_networks.parse_prompt(text)

        if is_positive:
            _, flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
        else:
            flat_list = [text]

        schedules = prompt_parser.get_learned_conditioning_prompt_schedules(flat_list, steps)
    except Exception:
        # a parsing error can happen here during typing, and we don't want to bother
        # the user with messages related to it in console
        schedules = [[[steps, text]]]

    prompts = [prompt_text for schedule in schedules for _step, prompt_text in schedule]
    token_count, max_length = max((model_hijack.get_prompt_lengths(prompt) for prompt in prompts), key=lambda counts: counts[0])
    return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
187
+
188
+
189
def update_negative_prompt_token_counter(*args):
    """Same as update_token_counter, but counts tokens for the negative prompt."""
    return update_token_counter(*args, is_positive=False)
191
+
192
+
193
def setup_progressbar(*args, **kwargs):
    """No-op that accepts and ignores any arguments; presumably kept so
    existing callers of the old progressbar API do not break — TODO confirm."""
195
+
196
+
197
def apply_setting(key, value):
    """Apply one setting value (typically pasted from infotext) to global opts.

    Returns the setting's new value on success, or a gradio no-op update when
    the change is rejected (None value, frozen settings, or a disallowed /
    unresolvable checkpoint swap). Persists the settings file and fires the
    setting's onchange handler when the stored value actually changed.
    """
    if value is None:
        return gr.update()

    if shared.cmd_opts.freeze_settings:
        return gr.update()

    # dont allow model to be swapped when model hash exists in prompt
    if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
        return gr.update()

    if key == "sd_model_checkpoint":
        # resolve whatever identifier was pasted to a known checkpoint title
        ckpt_info = sd_models.get_closet_checkpoint_match(value)

        if ckpt_info is not None:
            value = ckpt_info.title
        else:
            return gr.update()

    comp_args = opts.data_labels[key].component_args
    if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
        # NOTE(review): this branch returns None (bare return) while the guards
        # above return gr.update() — looks intentional for hidden components,
        # but worth confirming against the callers
        return

    # coerce to the type of the setting's default, unless the default is None
    valtype = type(opts.data_labels[key].default)
    oldval = opts.data.get(key, None)
    opts.data[key] = valtype(value) if valtype != type(None) else value
    if oldval != value and opts.data_labels[key].onchange is not None:
        opts.data_labels[key].onchange()

    opts.save(shared.config_filename)
    return getattr(opts, key)
228
+
229
+
230
def create_output_panel(tabname, outdir, toprow=None):
    """Thin pass-through to ui_common.create_output_panel, kept here so existing
    importers of this module keep working."""
    return ui_common.create_output_panel(tabname, outdir, toprow)
232
+
233
+
234
def ordered_ui_categories():
    """Yield UI section category names, reordered by the user's ui_reorder_list.

    Categories named in the user's list sort by their position there (odd keys);
    the rest keep their original relative order (even keys), interleaving stably.
    """
    preference = {name.strip(): 2 * position + 1 for position, name in enumerate(shared.opts.ui_reorder_list)}

    def sort_key(indexed):
        original_index, category = indexed
        return preference.get(category, 2 * original_index)

    for _, category in sorted(enumerate(shared_items.ui_reorder_categories()), key=sort_key):
        yield category
239
+
240
+
241
def create_override_settings_dropdown(tabname, row):
    """Create the hidden "Override settings" multiselect dropdown for a tab.

    The dropdown starts invisible and shows itself whenever it holds at least
    one selected entry.
    """
    dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)

    def toggle_visibility(selection):
        # visible only while something is selected
        return gr.Dropdown.update(visible=bool(selection))

    dropdown.change(
        fn=toggle_visibility,
        inputs=[dropdown],
        outputs=[dropdown],
    )

    return dropdown
251
+
252
+
253
+ def create_ui():
254
+ import modules.img2img
255
+ import modules.txt2img
256
+
257
+ reload_javascript()
258
+
259
+ parameters_copypaste.reset()
260
+
261
+ settings = ui_settings.UiSettings()
262
+ settings.register_settings()
263
+
264
+ scripts.scripts_current = scripts.scripts_txt2img
265
+ scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
266
+
267
+ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
268
+ toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box)
269
+
270
+ dummy_component = gr.Label(visible=False)
271
+
272
+ extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs", elem_classes=["extra-networks"])
273
+ extra_tabs.__enter__()
274
+
275
+ with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
276
+ with ExitStack() as stack:
277
+ if shared.opts.txt2img_settings_accordion:
278
+ stack.enter_context(gr.Accordion("Open for Settings", open=False))
279
+ stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings"))
280
+
281
+ scripts.scripts_txt2img.prepare_ui()
282
+
283
+ for category in ordered_ui_categories():
284
+ if category == "prompt":
285
+ toprow.create_inline_toprow_prompts()
286
+
287
+ elif category == "dimensions":
288
+ with FormRow():
289
+ with gr.Column(elem_id="txt2img_column_size", scale=4):
290
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
291
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
292
+
293
+ with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
294
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", tooltip="Switch width/height")
295
+
296
+ if opts.dimensions_and_batch_together:
297
+ with gr.Column(elem_id="txt2img_column_batch"):
298
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
299
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
300
+
301
+ elif category == "cfg":
302
+ with gr.Row():
303
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
304
+
305
+ elif category == "checkboxes":
306
+ with FormRow(elem_classes="checkboxes-row", variant="compact"):
307
+ pass
308
+
309
+ elif category == "accordions":
310
+ with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"):
311
+ with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr:
312
+ with enable_hr.extra():
313
+ hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
314
+
315
+ with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
316
+ hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
317
+ hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
318
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
319
+
320
+ with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
321
+ hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
322
+ hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
323
+ hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
324
+
325
+ with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
326
+
327
+ hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
328
+ create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
329
+
330
+ hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
331
+ hr_scheduler = gr.Dropdown(label='Hires schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler")
332
+
333
+ with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
334
+ with gr.Column(scale=80):
335
+ with gr.Row():
336
+ hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
337
+ with gr.Column(scale=80):
338
+ with gr.Row():
339
+ hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
340
+
341
+ with InputAccordion(False, label="Progressive Growing", elem_id="txt2img_progressive_growing") as enable_progressive_growing:
342
+ with FormRow(elem_id="txt2img_progressive_growing_row1", variant="compact"):
343
+ progressive_growing_min_scale = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, label="Min Scale", value=0.25, elem_id="txt2img_progressive_growing_min_scale")
344
+ progressive_growing_max_scale = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, label="Max Scale", value=1.0, elem_id="txt2img_progressive_growing_max_scale")
345
+
346
+ with FormRow(elem_id="txt2img_progressive_growing_row2", variant="compact"):
347
+ progressive_growing_steps = gr.Slider(minimum=2, maximum=10, step=1, label="Steps", value=4, elem_id="txt2img_progressive_growing_steps")
348
+ progressive_growing_refinement = gr.Checkbox(label="Enable Refinement", value=True, elem_id="txt2img_progressive_growing_refinement")
349
+
350
+ scripts.scripts_txt2img.setup_ui_for_section(category)
351
+
352
+ elif category == "batch":
353
+ if not opts.dimensions_and_batch_together:
354
+ with FormRow(elem_id="txt2img_column_batch"):
355
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
356
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
357
+
358
+ elif category == "override_settings":
359
+ with FormRow(elem_id="txt2img_override_settings_row") as row:
360
+ override_settings = create_override_settings_dropdown('txt2img', row)
361
+
362
+ elif category == "scripts":
363
+ with FormGroup(elem_id="txt2img_script_container"):
364
+ custom_inputs = scripts.scripts_txt2img.setup_ui()
365
+
366
+ if category not in {"accordions"}:
367
+ scripts.scripts_txt2img.setup_ui_for_section(category)
368
+
369
+ hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
370
+
371
+ for component in hr_resolution_preview_inputs:
372
+ event = component.release if isinstance(component, gr.Slider) else component.change
373
+
374
+ event(
375
+ fn=calc_resolution_hires,
376
+ inputs=hr_resolution_preview_inputs,
377
+ outputs=[hr_final_resolution],
378
+ show_progress=False,
379
+ )
380
+ event(
381
+ None,
382
+ _js="onCalcResolutionHires",
383
+ inputs=hr_resolution_preview_inputs,
384
+ outputs=[],
385
+ show_progress=False,
386
+ )
387
+
388
+ output_panel = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
389
+
390
+ txt2img_inputs = [
391
+ dummy_component,
392
+ toprow.prompt,
393
+ toprow.negative_prompt,
394
+ toprow.ui_styles.dropdown,
395
+ batch_count,
396
+ batch_size,
397
+ cfg_scale,
398
+ height,
399
+ width,
400
+ enable_hr,
401
+ denoising_strength,
402
+ hr_scale,
403
+ hr_upscaler,
404
+ hr_second_pass_steps,
405
+ hr_resize_x,
406
+ hr_resize_y,
407
+ hr_checkpoint_name,
408
+ hr_sampler_name,
409
+ hr_scheduler,
410
+ hr_prompt,
411
+ hr_negative_prompt,
412
+ override_settings,
413
+ enable_progressive_growing,
414
+ progressive_growing_min_scale,
415
+ progressive_growing_max_scale,
416
+ progressive_growing_steps,
417
+ progressive_growing_refinement,
418
+ ] + custom_inputs
419
+
420
+ txt2img_outputs = [
421
+ output_panel.gallery,
422
+ output_panel.generation_info,
423
+ output_panel.infotext,
424
+ output_panel.html_log,
425
+ ]
426
+
427
+ txt2img_args = dict(
428
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
429
+ _js="submit",
430
+ inputs=txt2img_inputs,
431
+ outputs=txt2img_outputs,
432
+ show_progress=False,
433
+ )
434
+
435
+ toprow.prompt.submit(**txt2img_args)
436
+ toprow.submit.click(**txt2img_args)
437
+
438
+ output_panel.button_upscale.click(
439
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']),
440
+ _js="submit_txt2img_upscale",
441
+ inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component, output_panel.generation_info] + txt2img_inputs[1:],
442
+ outputs=txt2img_outputs,
443
+ show_progress=False,
444
+ )
445
+
446
+ res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
447
+
448
+ toprow.restore_progress_button.click(
449
+ fn=progress.restore_progress,
450
+ _js="restoreProgressTxt2img",
451
+ inputs=[dummy_component],
452
+ outputs=[
453
+ output_panel.gallery,
454
+ output_panel.generation_info,
455
+ output_panel.infotext,
456
+ output_panel.html_log,
457
+ ],
458
+ show_progress=False,
459
+ )
460
+
461
+ txt2img_paste_fields = [
462
+ PasteField(toprow.prompt, "Prompt", api="prompt"),
463
+ PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
464
+ PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
465
+ PasteField(width, "Size-1", api="width"),
466
+ PasteField(height, "Size-2", api="height"),
467
+ PasteField(batch_size, "Batch size", api="batch_size"),
468
+ PasteField(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update(), api="styles"),
469
+ PasteField(denoising_strength, "Denoising strength", api="denoising_strength"),
470
+ PasteField(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d), api="enable_hr"),
471
+ PasteField(hr_scale, "Hires upscale", api="hr_scale"),
472
+ PasteField(hr_upscaler, "Hires upscaler", api="hr_upscaler"),
473
+ PasteField(hr_second_pass_steps, "Hires steps", api="hr_second_pass_steps"),
474
+ PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
475
+ PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
476
+ PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
477
+ PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api="hr_sampler_name"),
478
+ PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api="hr_scheduler"),
479
+ PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" or d.get("Hires schedule type", "Use same scheduler") != "Use same scheduler" else gr.update()),
480
+ PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
481
+ PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
482
+ PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
483
+ *scripts.scripts_txt2img.infotext_fields
484
+ ]
485
+ parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
486
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
487
+ paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
488
+ ))
489
+
490
+ steps = scripts.scripts_txt2img.script('Sampler').steps
491
+
492
+ txt2img_preview_params = [
493
+ toprow.prompt,
494
+ toprow.negative_prompt,
495
+ steps,
496
+ scripts.scripts_txt2img.script('Sampler').sampler_name,
497
+ cfg_scale,
498
+ scripts.scripts_txt2img.script('Seed').seed,
499
+ width,
500
+ height,
501
+ ]
502
+
503
+ toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
504
+ toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
505
+ toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
506
+ toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
507
+
508
+ extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
509
+ ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery)
510
+
511
+ extra_tabs.__exit__()
512
+
513
+ scripts.scripts_current = scripts.scripts_img2img
514
+ scripts.scripts_img2img.initialize_scripts(is_img2img=True)
515
+
516
+ with gr.Blocks(analytics_enabled=False) as img2img_interface:
517
+ toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box)
518
+
519
+ extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs", elem_classes=["extra-networks"])
520
+ extra_tabs.__enter__()
521
+
522
+ with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
523
+ with ExitStack() as stack:
524
+ if shared.opts.img2img_settings_accordion:
525
+ stack.enter_context(gr.Accordion("Open for Settings", open=False))
526
+ stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings"))
527
+
528
+ copy_image_buttons = []
529
+ copy_image_destinations = {}
530
+
531
+ def add_copy_image_controls(tab_name, elem):
532
+ with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
533
+ gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")
534
+
535
+ for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):
536
+ if name == tab_name:
537
+ gr.Button(title, interactive=False)
538
+ copy_image_destinations[name] = elem
539
+ continue
540
+
541
+ button = gr.Button(title)
542
+ copy_image_buttons.append((button, name, elem))
543
+
544
+ scripts.scripts_img2img.prepare_ui()
545
+
546
+ for category in ordered_ui_categories():
547
+ if category == "prompt":
548
+ toprow.create_inline_toprow_prompts()
549
+
550
+ if category == "image":
551
+ with gr.Tabs(elem_id="mode_img2img"):
552
+ img2img_selected_tab = gr.Number(value=0, visible=False)
553
+
554
+ with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
555
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
556
+ add_copy_image_controls('img2img', init_img)
557
+
558
+ with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
559
+ sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
560
+ add_copy_image_controls('sketch', sketch)
561
+
562
+ with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
563
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
564
+ add_copy_image_controls('inpaint', init_img_with_mask)
565
+
566
+ with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
567
+ inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
568
+ inpaint_color_sketch_orig = gr.State(None)
569
+ add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
570
+
571
+ def update_orig(image, state):
572
+ if image is not None:
573
+ same_size = state is not None and state.size == image.size
574
+ has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
575
+ edited = same_size and has_exact_match
576
+ return image if not edited or state is None else state
577
+
578
+ inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
579
+
580
+ with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
581
+ init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
582
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
583
+
584
+ with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
585
+ with gr.Tabs(elem_id="img2img_batch_source"):
586
+ img2img_batch_source_type = gr.Textbox(visible=False, value="upload")
587
+ with gr.TabItem('Upload', id='batch_upload', elem_id="img2img_batch_upload_tab") as tab_batch_upload:
588
+ img2img_batch_upload = gr.Files(label="Files", interactive=True, elem_id="img2img_batch_upload")
589
+ with gr.TabItem('From directory', id='batch_from_dir', elem_id="img2img_batch_from_dir_tab") as tab_batch_from_dir:
590
+ hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
591
+ gr.HTML(
592
+ "<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
593
+ "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
594
+ f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
595
+ f"{hidden}</p>"
596
+ )
597
+ img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
598
+ img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
599
+ img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
600
+ tab_batch_upload.select(fn=lambda: "upload", inputs=[], outputs=[img2img_batch_source_type])
601
+ tab_batch_from_dir.select(fn=lambda: "from dir", inputs=[], outputs=[img2img_batch_source_type])
602
+ with gr.Accordion("PNG info", open=False):
603
+ img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", elem_id="img2img_batch_use_png_info")
604
+ img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
605
+ img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
606
+
607
+ img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
608
+
609
+ for i, tab in enumerate(img2img_tabs):
610
+ tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
611
+
612
+ def copy_image(img):
613
+ if isinstance(img, dict) and 'image' in img:
614
+ return img['image']
615
+
616
+ return img
617
+
618
+ for button, name, elem in copy_image_buttons:
619
+ button.click(
620
+ fn=copy_image,
621
+ inputs=[elem],
622
+ outputs=[copy_image_destinations[name]],
623
+ )
624
+ button.click(
625
+ fn=lambda: None,
626
+ _js=f"switch_to_{name.replace(' ', '_')}",
627
+ inputs=[],
628
+ outputs=[],
629
+ )
630
+
631
+ with FormRow():
632
+ resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
633
+
634
+ elif category == "dimensions":
635
+ with FormRow():
636
+ with gr.Column(elem_id="img2img_column_size", scale=4):
637
+ selected_scale_tab = gr.Number(value=0, visible=False)
638
+
639
+ with gr.Tabs(elem_id="img2img_tabs_resize"):
640
+ with gr.Tab(label="Resize to", id="to", elem_id="img2img_tab_resize_to") as tab_scale_to:
641
+ with FormRow():
642
+ with gr.Column(elem_id="img2img_column_size", scale=4):
643
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
644
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
645
+ with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
646
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height")
647
+ detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img")
648
+
649
+ with gr.Tab(label="Resize by", id="by", elem_id="img2img_tab_resize_by") as tab_scale_by:
650
+ scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
651
+
652
+ with FormRow():
653
+ scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview")
654
+ gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider")
655
+ button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to")
656
+
657
+ on_change_args = dict(
658
+ fn=resize_from_to_html,
659
+ _js="currentImg2imgSourceResolution",
660
+ inputs=[dummy_component, dummy_component, scale_by],
661
+ outputs=scale_by_html,
662
+ show_progress=False,
663
+ )
664
+
665
+ scale_by.release(**on_change_args)
666
+ button_update_resize_to.click(**on_change_args)
667
+
668
+ tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
669
+ tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
670
+
671
+ if opts.dimensions_and_batch_together:
672
+ with gr.Column(elem_id="img2img_column_batch"):
673
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
674
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
675
+
676
+ elif category == "denoising":
677
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
678
+
679
+ elif category == "cfg":
680
+ with gr.Row():
681
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
682
+ image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
683
+
684
+ elif category == "checkboxes":
685
+ with FormRow(elem_classes="checkboxes-row", variant="compact"):
686
+ pass
687
+
688
+ elif category == "accordions":
689
+ with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"):
690
+ scripts.scripts_img2img.setup_ui_for_section(category)
691
+
692
+ elif category == "batch":
693
+ if not opts.dimensions_and_batch_together:
694
+ with FormRow(elem_id="img2img_column_batch"):
695
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
696
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
697
+
698
+ elif category == "override_settings":
699
+ with FormRow(elem_id="img2img_override_settings_row") as row:
700
+ override_settings = create_override_settings_dropdown('img2img', row)
701
+
702
+ elif category == "scripts":
703
+ with FormGroup(elem_id="img2img_script_container"):
704
+ custom_inputs = scripts.scripts_img2img.setup_ui()
705
+
706
+ elif category == "inpaint":
707
+ with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
708
+ with FormRow():
709
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
710
+ mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
711
+
712
+ with FormRow():
713
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
714
+
715
+ with FormRow():
716
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
717
+
718
+ with FormRow():
719
+ with gr.Column():
720
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
721
+
722
+ with gr.Column(scale=4):
723
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
724
+
725
+ if category not in {"accordions"}:
726
+ scripts.scripts_img2img.setup_ui_for_section(category)
727
+
728
+ # the code below is meant to update the resolution label after the image in the image selection UI has changed.
729
+ # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
730
+ # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
731
+ for component in [init_img, sketch]:
732
+ component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
733
+
734
def select_img2img_tab(tab):
    """Return visibility updates for the inpaint controls when an img2img sub-tab is selected.

    `tab` is the index of the selected img2img input tab; the inpaint control
    group is shown for the three inpaint-style tabs (indices 2, 3, 4), and the
    mask-transparency slider only for the color-sketch inpaint tab (index 3).
    """
    show_inpaint_controls = tab in [2, 3, 4]
    show_mask_alpha = tab == 3
    return gr.update(visible=show_inpaint_controls), gr.update(visible=show_mask_alpha),
736
+
737
+ for i, elem in enumerate(img2img_tabs):
738
+ elem.select(
739
+ fn=lambda tab=i: select_img2img_tab(tab),
740
+ inputs=[],
741
+ outputs=[inpaint_controls, mask_alpha],
742
+ )
743
+
744
+ output_panel = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
745
+
746
+ img2img_args = dict(
747
+ fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
748
+ _js="submit_img2img",
749
+ inputs=[
750
+ dummy_component,
751
+ dummy_component,
752
+ toprow.prompt,
753
+ toprow.negative_prompt,
754
+ toprow.ui_styles.dropdown,
755
+ init_img,
756
+ sketch,
757
+ init_img_with_mask,
758
+ inpaint_color_sketch,
759
+ inpaint_color_sketch_orig,
760
+ init_img_inpaint,
761
+ init_mask_inpaint,
762
+ mask_blur,
763
+ mask_alpha,
764
+ inpainting_fill,
765
+ batch_count,
766
+ batch_size,
767
+ cfg_scale,
768
+ image_cfg_scale,
769
+ denoising_strength,
770
+ selected_scale_tab,
771
+ height,
772
+ width,
773
+ scale_by,
774
+ resize_mode,
775
+ inpaint_full_res,
776
+ inpaint_full_res_padding,
777
+ inpainting_mask_invert,
778
+ img2img_batch_input_dir,
779
+ img2img_batch_output_dir,
780
+ img2img_batch_inpaint_mask_dir,
781
+ override_settings,
782
+ img2img_batch_use_png_info,
783
+ img2img_batch_png_info_props,
784
+ img2img_batch_png_info_dir,
785
+ img2img_batch_source_type,
786
+ img2img_batch_upload,
787
+ ] + custom_inputs,
788
+ outputs=[
789
+ output_panel.gallery,
790
+ output_panel.generation_info,
791
+ output_panel.infotext,
792
+ output_panel.html_log,
793
+ ],
794
+ show_progress=False,
795
+ )
796
+
797
+ interrogate_args = dict(
798
+ _js="get_img2img_tab_index",
799
+ inputs=[
800
+ dummy_component,
801
+ img2img_batch_input_dir,
802
+ img2img_batch_output_dir,
803
+ init_img,
804
+ sketch,
805
+ init_img_with_mask,
806
+ inpaint_color_sketch,
807
+ init_img_inpaint,
808
+ ],
809
+ outputs=[toprow.prompt, dummy_component],
810
+ )
811
+
812
+ toprow.prompt.submit(**img2img_args)
813
+ toprow.submit.click(**img2img_args)
814
+
815
+ res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
816
+
817
+ detect_image_size_btn.click(
818
+ fn=lambda w, h, _: (w or gr.update(), h or gr.update()),
819
+ _js="currentImg2imgSourceResolution",
820
+ inputs=[dummy_component, dummy_component, dummy_component],
821
+ outputs=[width, height],
822
+ show_progress=False,
823
+ )
824
+
825
+ toprow.restore_progress_button.click(
826
+ fn=progress.restore_progress,
827
+ _js="restoreProgressImg2img",
828
+ inputs=[dummy_component],
829
+ outputs=[
830
+ output_panel.gallery,
831
+ output_panel.generation_info,
832
+ output_panel.infotext,
833
+ output_panel.html_log,
834
+ ],
835
+ show_progress=False,
836
+ )
837
+
838
+ toprow.button_interrogate.click(
839
+ fn=lambda *args: process_interrogate(interrogate, *args),
840
+ **interrogate_args,
841
+ )
842
+
843
+ toprow.button_deepbooru.click(
844
+ fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
845
+ **interrogate_args,
846
+ )
847
+
848
+ steps = scripts.scripts_img2img.script('Sampler').steps
849
+
850
+ toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
851
+ toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
852
+ toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
853
+ toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
854
+
855
+ img2img_paste_fields = [
856
+ (toprow.prompt, "Prompt"),
857
+ (toprow.negative_prompt, "Negative prompt"),
858
+ (cfg_scale, "CFG scale"),
859
+ (image_cfg_scale, "Image CFG scale"),
860
+ (width, "Size-1"),
861
+ (height, "Size-2"),
862
+ (batch_size, "Batch size"),
863
+ (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
864
+ (denoising_strength, "Denoising strength"),
865
+ (mask_blur, "Mask blur"),
866
+ (inpainting_mask_invert, 'Mask mode'),
867
+ (inpainting_fill, 'Masked content'),
868
+ (inpaint_full_res, 'Inpaint area'),
869
+ (inpaint_full_res_padding, 'Masked area padding'),
870
+ *scripts.scripts_img2img.infotext_fields
871
+ ]
872
+ parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
873
+ parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
874
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
875
+ paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
876
+ ))
877
+
878
+ extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
879
+ ui_extra_networks.setup_ui(extra_networks_ui_img2img, output_panel.gallery)
880
+
881
+ extra_tabs.__exit__()
882
+
883
+ scripts.scripts_current = None
884
+
885
+ with gr.Blocks(analytics_enabled=False) as extras_interface:
886
+ ui_postprocessing.create_ui()
887
+
888
+ with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
889
+ with ResizeHandleRow(equal_height=False):
890
+ with gr.Column(variant='panel'):
891
+ image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
892
+
893
+ with gr.Column(variant='panel'):
894
+ html = gr.HTML()
895
+ generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
896
+ html2 = gr.HTML()
897
+ with gr.Row():
898
+ buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
899
+
900
+ for tabname, button in buttons.items():
901
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
902
+ paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
903
+ ))
904
+
905
+ image.change(
906
+ fn=wrap_gradio_call_no_job(modules.extras.run_pnginfo),
907
+ inputs=[image],
908
+ outputs=[html, generation_info, html2],
909
+ )
910
+
911
+ modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()
912
+
913
+ with gr.Blocks(analytics_enabled=False) as train_interface:
914
+ with gr.Row(equal_height=False):
915
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
916
+
917
+ with ResizeHandleRow(variant="compact", equal_height=False):
918
+ with gr.Tabs(elem_id="train_tabs"):
919
+
920
+ with gr.Tab(label="Create embedding", id="create_embedding"):
921
+ new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
922
+ initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
923
+ nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
924
+ overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
925
+
926
+ with gr.Row():
927
+ with gr.Column(scale=3):
928
+ gr.HTML(value="")
929
+
930
+ with gr.Column():
931
+ create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
932
+
933
+ with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"):
934
+ new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
935
+ new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
936
+ new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
937
+ new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func")
938
+ new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
939
+ new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
940
+ new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
941
+ new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
942
+ overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
943
+
944
+ with gr.Row():
945
+ with gr.Column(scale=3):
946
+ gr.HTML(value="")
947
+
948
+ with gr.Column():
949
+ create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
950
+
951
def get_textual_inversion_template_names():
    """Return the names of the known textual-inversion prompt templates, sorted alphabetically."""
    template_names = textual_inversion.textual_inversion_templates
    return sorted(template_names)
953
+
954
+ with gr.Tab(label="Train", id="train"):
955
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
956
+ with FormRow():
957
+ train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
958
+ create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
959
+
960
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
961
+ create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
962
+
963
+ with FormRow():
964
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
965
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
966
+
967
+ with FormRow():
968
+ clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
969
+ clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
970
+
971
+ with FormRow():
972
+ batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
973
+ gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
974
+
975
+ dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
976
+ log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
977
+
978
+ with FormRow():
979
+ template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
980
+ create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
981
+
982
+ training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
983
+ training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
984
+ varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
985
+ steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
986
+
987
+ with FormRow():
988
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
989
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
990
+
991
+ use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight")
992
+
993
+ save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
994
+ preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
995
+
996
+ shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
997
+ tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
998
+
999
+ latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
1000
+
1001
+ with gr.Row():
1002
+ train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
1003
+ interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
1004
+ train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
1005
+
1006
+ params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
1007
+
1008
+ script_callbacks.ui_train_tabs_callback(params)
1009
+
1010
+ with gr.Column(elem_id='ti_gallery_container'):
1011
+ ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
1012
+ gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4)
1013
+ gr.HTML(elem_id="ti_progress", value="")
1014
+ ti_outcome = gr.HTML(elem_id="ti_error", value="")
1015
+
1016
+ create_embedding.click(
1017
+ fn=textual_inversion_ui.create_embedding,
1018
+ inputs=[
1019
+ new_embedding_name,
1020
+ initialization_text,
1021
+ nvpt,
1022
+ overwrite_old_embedding,
1023
+ ],
1024
+ outputs=[
1025
+ train_embedding_name,
1026
+ ti_output,
1027
+ ti_outcome,
1028
+ ]
1029
+ )
1030
+
1031
+ create_hypernetwork.click(
1032
+ fn=hypernetworks_ui.create_hypernetwork,
1033
+ inputs=[
1034
+ new_hypernetwork_name,
1035
+ new_hypernetwork_sizes,
1036
+ overwrite_old_hypernetwork,
1037
+ new_hypernetwork_layer_structure,
1038
+ new_hypernetwork_activation_func,
1039
+ new_hypernetwork_initialization_option,
1040
+ new_hypernetwork_add_layer_norm,
1041
+ new_hypernetwork_use_dropout,
1042
+ new_hypernetwork_dropout_structure
1043
+ ],
1044
+ outputs=[
1045
+ train_hypernetwork_name,
1046
+ ti_output,
1047
+ ti_outcome,
1048
+ ]
1049
+ )
1050
+
1051
+ train_embedding.click(
1052
+ fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
1053
+ _js="start_training_textual_inversion",
1054
+ inputs=[
1055
+ dummy_component,
1056
+ train_embedding_name,
1057
+ embedding_learn_rate,
1058
+ batch_size,
1059
+ gradient_step,
1060
+ dataset_directory,
1061
+ log_directory,
1062
+ training_width,
1063
+ training_height,
1064
+ varsize,
1065
+ steps,
1066
+ clip_grad_mode,
1067
+ clip_grad_value,
1068
+ shuffle_tags,
1069
+ tag_drop_out,
1070
+ latent_sampling_method,
1071
+ use_weight,
1072
+ create_image_every,
1073
+ save_embedding_every,
1074
+ template_file,
1075
+ save_image_with_stored_embedding,
1076
+ preview_from_txt2img,
1077
+ *txt2img_preview_params,
1078
+ ],
1079
+ outputs=[
1080
+ ti_output,
1081
+ ti_outcome,
1082
+ ]
1083
+ )
1084
+
1085
+ train_hypernetwork.click(
1086
+ fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),
1087
+ _js="start_training_textual_inversion",
1088
+ inputs=[
1089
+ dummy_component,
1090
+ train_hypernetwork_name,
1091
+ hypernetwork_learn_rate,
1092
+ batch_size,
1093
+ gradient_step,
1094
+ dataset_directory,
1095
+ log_directory,
1096
+ training_width,
1097
+ training_height,
1098
+ varsize,
1099
+ steps,
1100
+ clip_grad_mode,
1101
+ clip_grad_value,
1102
+ shuffle_tags,
1103
+ tag_drop_out,
1104
+ latent_sampling_method,
1105
+ use_weight,
1106
+ create_image_every,
1107
+ save_embedding_every,
1108
+ template_file,
1109
+ preview_from_txt2img,
1110
+ *txt2img_preview_params,
1111
+ ],
1112
+ outputs=[
1113
+ ti_output,
1114
+ ti_outcome,
1115
+ ]
1116
+ )
1117
+
1118
+ interrupt_training.click(
1119
+ fn=lambda: shared.state.interrupt(),
1120
+ inputs=[],
1121
+ outputs=[],
1122
+ )
1123
+
1124
+ loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
1125
+ ui_settings_from_file = loadsave.ui_settings.copy()
1126
+
1127
+ settings.create_ui(loadsave, dummy_component)
1128
+
1129
+ interfaces = [
1130
+ (txt2img_interface, "txt2img", "txt2img"),
1131
+ (img2img_interface, "img2img", "img2img"),
1132
+ (extras_interface, "Extras", "extras"),
1133
+ (pnginfo_interface, "PNG Info", "pnginfo"),
1134
+ (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
1135
+ (train_interface, "Train", "train"),
1136
+ ]
1137
+
1138
+ interfaces += script_callbacks.ui_tabs_callback()
1139
+ interfaces += [(settings.interface, "Settings", "settings")]
1140
+
1141
+ extensions_interface = ui_extensions.create_ui()
1142
+ interfaces += [(extensions_interface, "Extensions", "extensions")]
1143
+
1144
+ shared.tab_names = []
1145
+ for _interface, label, _ifid in interfaces:
1146
+ shared.tab_names.append(label)
1147
+
1148
+ with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
1149
+ settings.add_quicksettings()
1150
+
1151
+ parameters_copypaste.connect_paste_params_buttons()
1152
+
1153
+ with gr.Tabs(elem_id="tabs") as tabs:
1154
+ tab_order = {k: i for i, k in enumerate(opts.ui_tab_order)}
1155
+ sorted_interfaces = sorted(interfaces, key=lambda x: tab_order.get(x[1], 9999))
1156
+
1157
+ for interface, label, ifid in sorted_interfaces:
1158
+ if label in shared.opts.hidden_tabs:
1159
+ continue
1160
+ with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
1161
+ interface.render()
1162
+
1163
+ if ifid not in ["extensions", "settings"]:
1164
+ loadsave.add_block(interface, ifid)
1165
+
1166
+ loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
1167
+
1168
+ loadsave.setup_ui()
1169
+
1170
+ if os.path.exists(os.path.join(script_path, "notification.mp3")) and shared.opts.notification_audio:
1171
+ gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
1172
+
1173
+ footer = shared.html("footer.html")
1174
+ footer = footer.format(versions=versions_html(), api_docs="/docs" if shared.cmd_opts.api else "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API")
1175
+ gr.HTML(footer, elem_id="footer")
1176
+
1177
+ settings.add_functionality(demo)
1178
+
1179
+ update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
1180
+ settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
1181
+ demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
1182
+
1183
+ modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
1184
+
1185
+ if ui_settings_from_file != loadsave.ui_settings:
1186
+ loadsave.dump_defaults()
1187
+ demo.ui_loadsave = loadsave
1188
+
1189
+ return demo
1190
+
1191
+
1192
+ def versions_html():
1193
+ import torch
1194
+ import launch
1195
+
1196
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
1197
+ commit = launch.commit_hash()
1198
+ tag = launch.git_tag()
1199
+
1200
+ if shared.xformers_available:
1201
+ import xformers
1202
+ xformers_version = xformers.__version__
1203
+ else:
1204
+ xformers_version = "N/A"
1205
+
1206
+ return f"""
1207
+ version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
1208
+ &#x2000;•&#x2000;
1209
+ python: <span title="{sys.version}">{python_version}</span>
1210
+ &#x2000;•&#x2000;
1211
+ torch: {getattr(torch, '__long_version__',torch.__version__)}
1212
+ &#x2000;•&#x2000;
1213
+ xformers: {xformers_version}
1214
+ &#x2000;•&#x2000;
1215
+ gradio: {gr.__version__}
1216
+ &#x2000;•&#x2000;
1217
+ checkpoint: <a id="sd_checkpoint_hash">N/A</a>
1218
+ """
1219
+
1220
+
1221
def setup_ui_api(app):
    """Register internal (non-UI) HTTP routes on the FastAPI `app`.

    Adds endpoints for quicksettings hints, liveness ping, startup profiling
    data, sysinfo download, and mounts the static webui assets directory.
    """
    from pydantic import BaseModel, Field

    class QuicksettingsHint(BaseModel):
        name: str = Field(title="Name of the quicksettings field")
        label: str = Field(title="Label of the quicksettings field")

    def quicksettings_hint():
        # One hint per registered option; `label` is the human-readable text.
        return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]

    app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=list[QuicksettingsHint])

    app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])

    app.add_api_route("/internal/profile-startup", lambda: timer.startup_record, methods=["GET"])

    def download_sysinfo(attachment=False):
        """Serve the sysinfo JSON dump; as a download when `attachment` is True."""
        from fastapi.responses import PlainTextResponse

        text = sysinfo.get()
        filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"

        # BUG FIX: the computed timestamped filename was previously unused and a
        # literal placeholder string was sent instead; interpolate it so the
        # browser saves the file under a meaningful name.
        disposition = "attachment" if attachment else "inline"
        return PlainTextResponse(text, headers={'Content-Disposition': f'{disposition}; filename="{filename}"'})

    app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"])
    app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"])

    import fastapi.staticfiles
    app.mount("/webui-assets", fastapi.staticfiles.StaticFiles(directory=launch_utils.repo_dir('stable-diffusion-webui-assets')), name="webui-assets")