SynLayers committed on
Commit 6e99a69 · verified · 1 Parent(s): 2d767a5

Upload models/pipeline.py with huggingface_hub

Files changed (1)
  1. models/pipeline.py +821 -0
models/pipeline.py ADDED
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Union
import einops

import torch
import torch.nn as nn
import torchvision.transforms as T

from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import is_torch_xla_available, logging
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxPipeline

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from models.multiLayer_adapter import MultiLayerAdapter

from PIL import Image

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm  # type: ignore
    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CustomFluxPipeline(FluxPipeline):

    @staticmethod
    def _prepare_latent_image_ids(height, width, list_layer_box, device, dtype):
        latent_image_ids_list = []
        for layer_idx in range(len(list_layer_box)):
            if list_layer_box[layer_idx] is None:
                continue
            latent_image_ids = torch.zeros(height // 2, width // 2, 3)  # [h/2, w/2, 3]
            latent_image_ids[..., 0] = layer_idx  # use the first dimension to encode the layer index
            latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
            latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

            # crop the ids to the layer's bounding box (pixel coordinates -> token coordinates)
            x1, y1, x2, y2 = list_layer_box[layer_idx]
            x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
            latent_image_ids = latent_image_ids[y1:y2, x1:x2, :]

            latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
            latent_image_ids = latent_image_ids.reshape(
                latent_image_id_height * latent_image_id_width, latent_image_id_channels
            )

            latent_image_ids_list.append(latent_image_ids)

        full_latent_image_ids = torch.cat(latent_image_ids_list, dim=0)

        return full_latent_image_ids.to(device=device, dtype=dtype)

    def prepare_latents(
        self,
        batch_size,
        num_layers,
        num_channels_latents,
        height,
        width,
        list_layer_box,
        dtype,
        device,
        generator,
        latents=None,
    ):
        # Here, the vae_scale_factor is 16, but the actual latent size is height // 8,
        # so we need to multiply by 2.
        height = 2 * (int(height) // self.vae_scale_factor)
        width = 2 * (int(width) // self.vae_scale_factor)

        shape = (batch_size, num_layers, num_channels_latents, height, width)  # (1, 15, 16, 64, 64)

        if latents is not None:
            latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype)
            return latents.to(device=device, dtype=dtype), latent_image_ids

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)  # [bs, f, c_latent, h, w]

        latent_image_ids = self._prepare_latent_image_ids(height, width, list_layer_box, device, dtype)

        return latents, latent_image_ids

    def prepare_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
    ):
        # Prepare image
        if not isinstance(image, torch.Tensor):
            image = self.image_processor.preprocess(image, height=height, width=width)

        image_batch_size = image.shape[0]
        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)
        image = image.to(device=device, dtype=dtype)  # (1, C, H, W)

        # create a blank mask; the mask is currently not used in practice
        mask = Image.new("RGB", (width, height), (0, 0, 0))

        # Prepare mask
        if not isinstance(mask, torch.Tensor):
            self.mask_processor = VaeImageProcessor(
                vae_scale_factor=self.vae_scale_factor,
                do_resize=True,
                do_convert_grayscale=True,
                do_normalize=False,
                do_binarize=True,
            )
            mask = self.mask_processor.preprocess(mask, height=height, width=width)
        mask = mask.repeat_interleave(repeat_by, dim=0)
        mask = mask.to(device=device, dtype=dtype)  # (1, 1, H, W)

        # Get masked image
        masked_image = image.clone()
        masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1  # (1, 3, H, W)

        # Encode to latents
        image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample()
        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        image_latents = image_latents.to(dtype)  # (1, 16, H/8, W/8)

        mask = torch.nn.functional.interpolate(
            mask, size=(height // self.vae_scale_factor * 2, width // self.vae_scale_factor * 2)
        )
        mask = 1 - mask  # (1, 1, H/8, W/8)

        adapter_image = torch.cat([image_latents, mask], dim=1)

        # Pack cond latents into the Flux token sequence layout
        packed_adapter_image = self._pack_latents(
            adapter_image,
            batch_size * num_images_per_prompt,
            adapter_image.shape[1],
            adapter_image.shape[2],
            adapter_image.shape[3],
        )

        if do_classifier_free_guidance:
            packed_adapter_image = torch.cat([packed_adapter_image] * 2)

        return packed_adapter_image, height, width

    def set_multiLayerAdapter(self, multiLayerAdapter):
        self.multiLayerAdapter = multiLayerAdapter

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        validation_box: List[tuple] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        num_layers: int = 5,
        sdxl_vae: nn.Module = None,
        transparent_decoder: nn.Module = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
                will be used instead.
            validation_box (`List[tuple]`, *optional*):
                Per-layer bounding boxes `(x1, y1, x2, y2)` in pixel coordinates, one entry per layer; a `None`
                entry skips that layer.
            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 28):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from the `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, *optional*, defaults to 512):
                Maximum sequence length to use with the `prompt`.
            num_layers (`int`, *optional*, defaults to 5):
                The number of layers to generate; should match the length of `validation_box`.
            sdxl_vae (`nn.Module`, *optional*):
                An SDXL VAE used to re-encode the decoded image before the transparent decoder is applied.
            transparent_decoder (`nn.Module`, *optional*):
                A decoder that recovers per-layer transparency from the SDXL latents.

        Examples:

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: a 3-tuple whose first element is a
            [`~pipelines.flux.FluxPipelineOutput`] with the generated images if `return_dict` is `True` (the
            generated images themselves otherwise), followed by `result_list` and `vis_list` from the
            transparent decoder (both `None` when `sdxl_vae` is not provided).
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode the prompt
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_layers,
            num_channels_latents,
            height,
            width,
            validation_box,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = latent_image_ids.shape[0]  # total number of image tokens across all layers
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=latents,
                    list_layer_box=validation_box,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (e.g. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        # create a grey latent canvas (zeros decode to mid-grey)
        bs, n_frames, channel_latent, height, width = latents.shape

        pixel_grey = torch.zeros(size=(bs * n_frames, 3, height * 8, width * 8), device=latents.device, dtype=latents.dtype)
        latent_grey = self.vae.encode(pixel_grey).latent_dist.sample()
        latent_grey = (latent_grey - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        latent_grey = latent_grey.view(bs, n_frames, channel_latent, height, width)  # [bs, f, c_latent, h, w]

        # fill the denoised layer latents into their bounding boxes
        for layer_idx in range(latent_grey.shape[1]):
            if validation_box[layer_idx] is None:
                continue
            x1, y1, x2, y2 = validation_box[layer_idx]
            x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8
            latent_grey[:, layer_idx, :, y1:y2, x1:x2] = latents[:, layer_idx, :, y1:y2, x1:x2]
        latents = latent_grey

        result_list, vis_list = None, None
        if output_type == "latent":
            image = latents

        else:
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            latents = latents.reshape(bs * n_frames, channel_latent, height, width)
            image = self.vae.decode(latents, return_dict=False)[0]
            if sdxl_vae is not None:
                sdxl_vae = sdxl_vae.to(dtype=image.dtype, device=image.device)
                sdxl_latents = sdxl_vae.encode(image).latent_dist.sample()
                transparent_decoder = transparent_decoder.to(dtype=image.dtype, device=image.device)
                result_list, vis_list = transparent_decoder(sdxl_vae, sdxl_latents)
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, result_list, vis_list)

        return FluxPipelineOutput(images=image), result_list, vis_list


class CustomFluxPipelineCfgLayer(CustomFluxPipeline):

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        validation_box: List[tuple] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 3.5,
        true_gs: float = 3.5,
        adapter_image: PipelineImageInput = None,
        adapter_mask: PipelineImageInput = None,
        adapter_conditioning_scale: Union[float, List[float]] = 1.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        num_layers: int = 5,
        sdxl_vae: nn.Module = None,
        transparent_decoder: nn.Module = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
                will be used instead.
            validation_box (`List[tuple]`, *optional*):
                Per-layer bounding boxes `(x1, y1, x2, y2)` in pixel coordinates, one entry per layer; a `None`
                entry skips that layer.
            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 28):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            true_gs (`float`, *optional*, defaults to 3.5):
                True classifier-free guidance scale applied between the positive and the negative (empty-prompt)
                predictions.
            adapter_image (`PipelineImageInput`, *optional*):
                Conditioning image for the multi-layer adapter.
            adapter_mask (`PipelineImageInput`, *optional*):
                Conditioning mask for the adapter; accepted but currently unused.
            adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                The scale applied to the adapter's residual outputs.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from the `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, *optional*, defaults to 512):
                Maximum sequence length to use with the `prompt`.
            num_layers (`int`, *optional*, defaults to 5):
                The number of layers to generate; should match the length of `validation_box`.
            sdxl_vae (`nn.Module`, *optional*):
                The transparent VAE decoder module; called as `sdxl_vae(x, box, use_layers, z_2d)` to recover
                foreground and transparency.
            transparent_decoder (`nn.Module`, *optional*):
                Accepted for API compatibility with `CustomFluxPipeline`; this variant calls `sdxl_vae` directly.

        Examples:

        Returns:
            `tuple`: a one-element tuple with the generated images when `return_dict` is `False`; otherwise a
            3-tuple `(x_hat, image, latents)` with the transparent-decoder output (`None` when `sdxl_vae` is not
            provided), the generated RGB images, and the final latents.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # Encode the positive prompt and an empty negative prompt for true CFG
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
        (
            neg_prompt_embeds,
            neg_pooled_prompt_embeds,
            neg_text_ids,
        ) = self.encode_prompt(
            prompt="",
            prompt_2=None,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 3. Prepare adapter image
        if isinstance(self.multiLayerAdapter, MultiLayerAdapter):
            adapter_image, _, _ = self.prepare_image(
                image=adapter_image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=self.transformer.dtype,
            )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_layers,
            num_channels_latents,
            height,
            width,
            validation_box,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = latent_image_ids.shape[0]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                (
                    adapter_block_samples,
                    adapter_single_block_samples,
                ) = self.multiLayerAdapter(
                    hidden_states=latents,
                    list_layer_box=validation_box,
                    adapter_cond=adapter_image,
                    conditioning_scale=adapter_conditioning_scale,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )

                noise_pred = self.transformer(
                    hidden_states=latents,
                    list_layer_box=validation_box,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    adapter_block_samples=[
                        sample.to(dtype=self.transformer.dtype)
                        for sample in adapter_block_samples
                    ],
                    adapter_single_block_samples=[
                        sample.to(dtype=self.transformer.dtype)
                        for sample in adapter_single_block_samples
                    ] if adapter_single_block_samples is not None else adapter_single_block_samples,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                neg_noise_pred = self.transformer(
                    hidden_states=latents,
                    list_layer_box=validation_box,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=neg_pooled_prompt_embeds,
                    encoder_hidden_states=neg_prompt_embeds,
                    adapter_block_samples=[
                        sample.to(dtype=self.transformer.dtype)
                        for sample in adapter_block_samples
                    ],
                    adapter_single_block_samples=[
                        sample.to(dtype=self.transformer.dtype)
                        for sample in adapter_single_block_samples
                    ] if adapter_single_block_samples is not None else adapter_single_block_samples,
                    txt_ids=neg_text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # true classifier-free guidance between positive and negative predictions
                noise_pred = neg_noise_pred + true_gs * (noise_pred - neg_noise_pred)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (e.g. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        # create a grey latent canvas (zeros decode to mid-grey)
        bs, n_frames, channel_latent, height, width = latents.shape

        def encode_in_chunks(vae, images, chunk=8):
            # encode in chunks to keep peak VRAM bounded
            parts = []
            for i in range(0, images.shape[0], chunk):
                chunk_img = images[i : i + chunk]
                part_latent = vae.encode(chunk_img).latent_dist.sample()
                parts.append(part_latent)
                torch.cuda.empty_cache()
            return torch.cat(parts, dim=0)

        pixel_grey = torch.zeros(size=(bs * n_frames, 3, height * 8, width * 8), device=latents.device, dtype=latents.dtype)
        latent_grey = encode_in_chunks(self.vae, pixel_grey, chunk=16)
        latent_grey = (latent_grey - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        latent_grey = latent_grey.view(bs, n_frames, channel_latent, height, width)  # [bs, f, c_latent, h, w]

        # fill the denoised layer latents into their bounding boxes
        for layer_idx in range(latent_grey.shape[1]):
            if validation_box[layer_idx] is None:
                continue
            x1, y1, x2, y2 = validation_box[layer_idx]
            x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8
            latent_grey[:, layer_idx, :, y1:y2, x1:x2] = latents[:, layer_idx, :, y1:y2, x1:x2]
        latents = latent_grey

        x_hat = None
        if output_type == "latent":
            image = latents

        else:
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            bs, num_layers, c, h, w = latents.shape
            latents = latents.reshape(bs * n_frames, channel_latent, height, width)
            # decode in segments of 8 to bound memory usage
            latents_segs = torch.split(latents, 8, dim=0)
            image_segs = [self.vae.decode(latents_seg, return_dict=False)[0] for latents_seg in latents_segs]
            image = torch.cat(image_segs, dim=0)
            if sdxl_vae is not None:
                sdxl_vae = sdxl_vae.to(dtype=image.dtype, device=image.device)

                # Prepare input parameters
                _, c1, h1, w1 = image.shape  # get channels and spatial dimensions from image
                x = image.view(bs, num_layers, c1, h1, w1).permute(0, 2, 1, 3, 4).to(image.device)  # (bs, c, num_layers, h, w)
                box = [validation_box] * bs  # box info for each sample
                use_layers = [list(range(len(b))) for b in box]  # use all layers
                z_2d = latents.view(bs, num_layers, -1, h, w)  # (bs, num_layers, c, h, w)
                z_2d = einops.rearrange(z_2d, "b t c h w -> b c t h w").to(image.device)  # (bs, c, num_layers, h, w)

                # Call transparent VAE decoder
                x_hat = sdxl_vae(x, box, use_layers, z_2d).to(x.dtype).clamp(-1, 1)
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return (
            x_hat,    # final decoded result including foreground and transparency (None if sdxl_vae is not given)
            image,    # final generated RGB image
            latents,  # latent variables
        )