| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import inspect |
| from typing import Any, Callable, Dict, List, Optional, Union |
|
|
| import torch |
| from transformers import ( |
| CLIPTextModelWithProjection, |
| CLIPTokenizer, |
| SiglipImageProcessor, |
| SiglipVisionModel, |
| T5EncoderModel, |
| T5TokenizerFast, |
| ) |
|
|
| from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback |
| from diffusers.image_processor import PipelineImageInput, VaeImageProcessor |
| from diffusers.loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin |
| from diffusers.models.autoencoders import AutoencoderKL |
| from diffusers.models.transformers import SD3Transformer2DModel |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler |
| from diffusers.utils import ( |
| USE_PEFT_BACKEND, |
| is_torch_xla_available, |
| logging, |
| replace_example_docstring, |
| scale_lora_layers, |
| unscale_lora_layers, |
| ) |
| from diffusers.utils.torch_utils import randn_tensor |
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline |
| from .pipeline_output import SiDPipelineOutput |
|
|
|
|
# Optional torch_xla support: `xm` is only bound (and the flag only flipped) when torch_xla is importable.
XLA_AVAILABLE = False
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True


# Module logger; the pipeline uses it for prompt-truncation and CPU-offload warnings.
logger = logging.get_logger(__name__)
|
|
|
|
| |
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map a sequence length to a shift value.

    The line passes through ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``;
    ``image_seq_len`` is evaluated on it (values outside the range are extrapolated).

    Args:
        image_seq_len: Sequence length to evaluate.
        base_seq_len: Sequence length mapped to ``base_shift``.
        max_seq_len: Sequence length mapped to ``max_shift``.
        base_shift: Shift at ``base_seq_len``.
        max_shift: Shift at ``max_seq_len``.

    Returns:
        The interpolated shift ``image_seq_len * slope + intercept``.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
|
|
|
|
| |
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    The schedule is specified in exactly one way: custom `timesteps`, custom `sigmas`, or (when neither is
    given) `num_inference_steps`. Any extra keyword arguments are forwarded unchanged to
    `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            Scheduler whose `set_timesteps` is called and whose `timesteps` attribute is read afterwards.
        num_inference_steps (`int`, *optional*):
            Step count used when neither `timesteps` nor `sigmas` is given; returned unchanged in that case.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device=`.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule. `set_timesteps` must accept a `timesteps` argument.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule. `set_timesteps` must accept a `sigmas` argument.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after the call, and the step count (the length of that
        schedule when custom `timesteps`/`sigmas` were used).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if `scheduler.set_timesteps` does not accept
            the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    # Default path: the scheduler derives its own spacing from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom path: exactly one of `timesteps` / `sigmas` is set here.
    if timesteps is not None:
        arg_name, arg_value, schedule_word = "timesteps", timesteps, "timestep"
    else:
        arg_name, arg_value, schedule_word = "sigmas", sigmas, "sigmas"

    accepted_params = inspect.signature(scheduler.set_timesteps).parameters
    if arg_name not in accepted_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {schedule_word} schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(device=device, **{arg_name: arg_value}, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
|
|
|
|
class SiDSD3Pipeline(
    DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
):
    r"""
    Few-step text-to-image pipeline for a Stable Diffusion 3 transformer using the SiD sampling loop in
    `__call__`.

    The loop does not step `scheduler` and does not apply classifier-free guidance: each step runs one
    conditional transformer pass (see `__call__` for the update rule).

    Args:
        transformer ([`SD3Transformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            Registered for compatibility with SD3 checkpoints; `__call__` does not use it.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModelWithProjection`]):
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
            with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
            as its dimension.
        text_encoder_2 ([`CLIPTextModelWithProjection`]):
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
            specifically the
            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
            variant.
        text_encoder_3 ([`T5EncoderModel`]):
            Frozen text-encoder. Stable Diffusion 3 uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
            [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. If `None`, zero embeddings are used in
            its place.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        tokenizer_2 (`CLIPTokenizer`):
            Second Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        tokenizer_3 (`T5TokenizerFast`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        image_encoder (`SiglipVisionModel`, *optional*):
            Pre-trained Vision Model for IP Adapter.
        feature_extractor (`SiglipImageProcessor`, *optional*):
            Image processor for IP Adapter.
    """

    model_cpu_offload_seq = (
        "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
    )
    _optional_components = ["image_encoder", "feature_extractor"]
    # Tensors a `callback_on_step_end` may request and replace.
    _callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"]

    def __init__(
        self,
        transformer: SD3Transformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer_2: CLIPTokenizer,
        text_encoder_3: T5EncoderModel,
        tokenizer_3: T5TokenizerFast,
        image_encoder: SiglipVisionModel = None,
        feature_extractor: SiglipImageProcessor = None,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            text_encoder_3=text_encoder_3,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            tokenizer_3=tokenizer_3,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        # Pixel-to-latent downsampling factor: 2 ** (len(block_out_channels) - 1), or 8 without a VAE.
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1)
            if getattr(self, "vae", None)
            else 8
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        # CLIP token limit (also the sequence length of the T5 zero fallback); 77 without a tokenizer.
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length
            if hasattr(self, "tokenizer") and self.tokenizer is not None
            else 77
        )
        # Default latent side length; __call__ multiplies it by vae_scale_factor for the pixel size.
        self.default_sample_size = (
            self.transformer.config.sample_size
            if hasattr(self, "transformer") and self.transformer is not None
            else 128
        )
        self.patch_size = (
            self.transformer.config.patch_size
            if hasattr(self, "transformer") and self.transformer is not None
            else 2
        )

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 256,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Encode `prompt` with `tokenizer_3` / `text_encoder_3`.

        Args:
            prompt: Prompt or list of prompts.
            num_images_per_prompt: Each prompt's embedding is repeated this many times, consecutively, along the
                batch axis.
            max_sequence_length: Padding/truncation length for the T5 tokenizer.
            device: Target device; defaults to `self._execution_device`.
            dtype: dtype of the zero fallback only; real T5 outputs are cast to `text_encoder_3.dtype`.

        Returns:
            Tensor `(batch_size * num_images_per_prompt, seq_len, hidden)`. When `text_encoder_3` is `None`, zeros of
            shape `(batch_size * num_images_per_prompt, tokenizer_max_length, transformer.config.joint_attention_dim)`.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if self.text_encoder_3 is None:
            # No T5 loaded: substitute zeros so the joint sequence layout stays valid.
            return torch.zeros(
                (
                    batch_size * num_images_per_prompt,
                    self.tokenizer_max_length,
                    self.transformer.config.joint_attention_dim,
                ),
                device=device,
                dtype=dtype,
            )

        text_inputs = self.tokenizer_3(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_3(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            # The T5 input was cut at `max_sequence_length` (not the CLIP limit), so report the tail from there.
            removed_text = self.tokenizer_3.batch_decode(
                untruncated_ids[:, max_sequence_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]

        dtype = self.text_encoder_3.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # (B, S, D) -> (B, n*S, D) -> (B*n, S, D): each prompt's copies end up adjacent in the batch.
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )

        return prompt_embeds

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        clip_skip: Optional[int] = None,
        clip_model_index: int = 0,
    ):
        """
        Encode `prompt` with one of the two CLIP encoders.

        Args:
            prompt: Prompt or list of prompts.
            num_images_per_prompt: Repeat count per prompt along the batch axis (copies are adjacent).
            device: Target device; defaults to `self._execution_device`.
            clip_skip: `None` uses `hidden_states[-2]`; an int `k` uses `hidden_states[-(k + 2)]`.
            clip_model_index: 0 selects `tokenizer`/`text_encoder`, 1 selects `tokenizer_2`/`text_encoder_2`.

        Returns:
            `(prompt_embeds, pooled_prompt_embeds)` with shapes `(batch * n, seq_len, hidden)` and `(batch * n, dim)`.
            The pooled tensor is the first element of the encoder output (presumably the projected text embedding
            of `CLIPTextModelWithProjection`).
        """
        device = device or self._execution_device

        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
        clip_text_encoders = [self.text_encoder, self.text_encoder_2]

        tokenizer = clip_tokenizers[clip_model_index]
        text_encoder = clip_text_encoders[clip_model_index]

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = text_encoder(
            text_input_ids.to(device), output_hidden_states=True
        )
        pooled_prompt_embeds = prompt_embeds[0]

        if clip_skip is None:
            prompt_embeds = prompt_embeds.hidden_states[-2]
        else:
            prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

        # Both CLIP branches are cast to the first encoder's dtype, keeping them in one dtype for the
        # concatenation in `encode_prompt`.
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(
            batch_size * num_images_per_prompt, -1
        )

        return prompt_embeds, pooled_prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        prompt_3: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        clip_skip: Optional[int] = None,
        max_sequence_length: int = 256,
    ):
        r"""
        Build the SD3 joint text conditioning from the two CLIP encoders and T5.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used.
            prompt_3 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
                used.
            device: (`torch.device`):
                torch device; defaults to `self._execution_device`.
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. When given, no encoder is run and `prompt_embeds` /
                `pooled_prompt_embeds` are returned unchanged (no repetition by `num_images_per_prompt`).
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings, returned as-is together with `prompt_embeds`.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            max_sequence_length (`int`):
                T5 padding/truncation length.

        Returns:
            `(prompt_embeds, pooled_prompt_embeds)`. When encoding, `prompt_embeds` is the CLIP-L and CLIP-G token
            embeddings concatenated on the feature axis, zero-padded to the T5 width, then concatenated with the T5
            embeddings on the sequence axis; `pooled_prompt_embeds` is the two pooled CLIP outputs concatenated on
            the feature axis.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            prompt_3 = prompt_3 or prompt
            prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3

            prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=clip_skip,
                clip_model_index=0,
            )
            prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
                prompt=prompt_2,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                clip_skip=clip_skip,
                clip_model_index=1,
            )
            clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

            t5_prompt_embed = self._get_t5_prompt_embeds(
                prompt=prompt_3,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

            # Right-pad the CLIP features with zeros to the T5 feature width so the two can share one sequence.
            clip_prompt_embeds = torch.nn.functional.pad(
                clip_prompt_embeds,
                (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]),
            )

            prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
            pooled_prompt_embeds = torch.cat(
                [pooled_prompt_embed, pooled_prompt_2_embed], dim=-1
            )

        return (
            prompt_embeds,
            pooled_prompt_embeds,
        )

    def check_inputs(
        self,
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=None,
        negative_prompt_2=None,
        negative_prompt_3=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        """
        Validate `__call__` arguments, raising `ValueError` on the first problem found.

        Checks that `height`/`width` are multiples of `vae_scale_factor * patch_size`, that requested callback
        tensors are listed in `_callback_tensor_inputs`, that prompts and `prompt_embeds` are not both given and at
        least one is, that prompts are `str` or `list`, that embedding pairs are given consistently, and that
        `max_sequence_length` is at most 512. The negative-prompt arguments are validated if given, but `__call__`
        does not pass them.
        """
        multiple = self.vae_scale_factor * self.patch_size
        if height % multiple != 0 or width % multiple != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by {multiple} but are {height} and {width}."
                f" You can use height {height - height % multiple} and width {width - width % multiple}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs
            for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_3 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and not isinstance(prompt, (str, list)):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )
        elif prompt_2 is not None and not isinstance(prompt_2, (str, list)):
            raise ValueError(
                f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
            )
        elif prompt_3 is not None and not isinstance(prompt_3, (str, list)):
            raise ValueError(
                f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}"
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(
                f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
            )

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        """
        Return initial latents, sampling Gaussian noise when `latents` is not given.

        Args:
            batch_size: Number of latents to sample (already multiplied by `num_images_per_prompt`).
            num_channels_latents: Latent channel count.
            height, width: Pixel size; divided by `vae_scale_factor` for the latent grid.
            dtype, device: dtype/device of the returned tensor.
            generator: `torch.Generator` or a list with one generator per batch element.
            latents: Optional user-provided latents; only moved/cast, not validated.

        Returns:
            Tensor `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)` when
            sampled.

        Raises:
            ValueError: If `generator` is a list whose length differs from `batch_size`.
        """
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        return randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    @property
    def guidance_scale(self):
        """Guidance scale passed to the last `__call__` (stored only; not applied by the SiD loop)."""
        return self._guidance_scale

    @property
    def skip_guidance_layers(self):
        """Always `None` after `__call__`; kept for SD3 pipeline API compatibility."""
        return self._skip_guidance_layers

    @property
    def clip_skip(self):
        """CLIP layer skip used for prompt encoding in the last `__call__`."""
        return self._clip_skip

    @property
    def do_classifier_free_guidance(self):
        """Whether `guidance_scale > 1` (informational; the SiD loop does not run an unconditional pass)."""
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        """Always `None` after `__call__`; kept for SD3 pipeline API compatibility."""
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        """Number of sampling steps of the last `__call__`."""
        return self._num_timesteps

    @property
    def interrupt(self):
        """When set to `True` (e.g. from a step callback), remaining sampling steps are skipped."""
        return self._interrupt

    def enable_sequential_cpu_offload(self, *args, **kwargs):
        """Warn about `image_encoder` offload pitfalls, then defer to `DiffusionPipeline`."""
        if (
            self.image_encoder is not None
            and "image_encoder" not in self._exclude_from_cpu_offload
        ):
            logger.warning(
                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
            )

        super().enable_sequential_cpu_offload(*args, **kwargs)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        guidance_scale: float = 1.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 256,
        use_sd3_shift: bool = False,
        noise_type: str = "fresh",
        time_scale: float = 1000.0,
    ):
        r"""
        Generate images with the few-step SiD sampling loop.

        At step `i` the noise level is `t = 1 - i / num_inference_steps`, optionally remapped by the SD3 shift
        `3t / (1 + 2t)`. The transformer input is `x_t = (1 - t) * D_x + t * noise`, where `D_x` is the previous
        clean-latent estimate (zeros at step 0). The new estimate is `D_x = x_t - t * flow_pred`, and the final `D_x`
        is decoded by the VAE.

        Args:
            prompt, prompt_2, prompt_3 (`str` or `List[str]`, *optional*):
                Prompts for CLIP-L, CLIP-G and T5 respectively; `prompt_2`/`prompt_3` default to `prompt`.
            height, width (`int`, *optional*):
                Output size in pixels; default `default_sample_size * vae_scale_factor`. Must be divisible by
                `vae_scale_factor * patch_size`.
            num_inference_steps (`int`):
                Number of transformer evaluations.
            guidance_scale (`float`):
                Stored as `self.guidance_scale` but not applied; each step runs one conditional pass.
            num_images_per_prompt (`int`):
                Images generated per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                Used for the initial latents and, with `noise_type="fresh"`, the per-step noise.
            latents (`torch.FloatTensor`, *optional*):
                Initial noise; sampled with `generator` when omitted.
            prompt_embeds, pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Precomputed conditioning (see `encode_prompt`); must be passed together.
            output_type (`str`):
                `"latent"` returns the final latent estimate without VAE decoding; anything else is forwarded to
                `VaeImageProcessor.postprocess` (e.g. `"pil"`, `"np"`, `"pt"`).
            return_dict (`bool`):
                Return a `SiDPipelineOutput` instead of a tuple.
            callback_on_step_end (`Callable`, `PipelineCallback` or `MultiPipelineCallbacks`, *optional*):
                Called after each step as `callback(self, i, timestep, callback_kwargs)`, where `timestep` is the
                scaled timestep fed to the transformer. It must return a dict; returned `latents` / `prompt_embeds` /
                `pooled_prompt_embeds` replace the pipeline's tensors. `latents` here is the step's transformer input
                `x_t`, so replacing it only affects the next step when `noise_type="ddim"`.
            callback_on_step_end_tensor_inputs (`List[str]`):
                Names from `_callback_tensor_inputs` passed to the callback.
            max_sequence_length (`int`):
                T5 sequence length (at most 512).
            use_sd3_shift (`bool`):
                Remap `t` with shift 3 as described above.
            noise_type (`str`):
                `"fresh"`: new Gaussian noise every step after the first; `"ddim"`: reuse the noise implied by the
                previous step's input and the latest `D_x`; `"fixed"`: reuse the initial latents every step.
            time_scale (`float`):
                Multiplier applied to `t` to form the transformer `timestep`.

        Returns:
            `SiDPipelineOutput(images=...)`, or `(images,)` when `return_dict=False`.

        Raises:
            ValueError: On invalid inputs (see `check_inputs`) or an unknown `noise_type`.
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        self.check_inputs(
            prompt,
            prompt_2,
            prompt_3,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )
        # Fail fast, before any encoder runs.
        if noise_type not in ("fresh", "ddim", "fixed"):
            raise ValueError(f"Unknown noise_type: {noise_type}")

        # Initialise every attribute the public properties read so they never raise AttributeError.
        self._guidance_scale = guidance_scale
        self._clip_skip = None
        self._joint_attention_kwargs = None
        self._skip_guidance_layers = None
        self._num_timesteps = num_inference_steps
        self._interrupt = False

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        prompt_embeds, pooled_prompt_embeds = self.encode_prompt(
            prompt,
            prompt_2,
            prompt_3,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            clip_skip=self.clip_skip,
            max_sequence_length=max_sequence_length,
        )

        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # Running clean-latent estimate; zeros make step 0 start from pure noise.
        D_x = torch.zeros_like(latents)
        initial_latents = latents.clone() if noise_type == "fixed" else None
        t = None  # noise level of the previous step (read by the "ddim" branch)

        for i in range(num_inference_steps):
            if self.interrupt:
                continue

            if noise_type == "fresh":
                noise = (
                    latents
                    if i == 0
                    else randn_tensor(
                        latents.shape,
                        generator=generator,
                        device=latents.device,
                        dtype=latents.dtype,
                    )
                )
            elif noise_type == "ddim":
                # Invert x_t = (1 - t) * D_x + t * noise using the previous step's t and input.
                noise = (
                    latents if i == 0 else ((latents - (1.0 - t) * D_x) / t).detach()
                )
            else:  # "fixed"
                noise = initial_latents

            # Linear schedule from t = 1 down to 1 / num_inference_steps (computed on a 0..999 grid).
            t_val = (999.0 * (1.0 - float(i) / float(num_inference_steps))) / 999.0
            if use_sd3_shift:
                shift = 3.0
                t_val = shift * t_val / (1 + (shift - 1) * t_val)

            # One noise level per sample, shaped to broadcast over (C, H, W).
            t = torch.full(
                (latents.shape[0], 1, 1, 1),
                t_val,
                device=latents.device,
                dtype=latents.dtype,
            )
            timestep = time_scale * t.flatten()

            latents = (1.0 - t) * D_x + t * noise

            flow_pred = self.transformer(
                hidden_states=latents,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                timestep=timestep,
                return_dict=False,
            )[0]
            D_x = latents - t * flow_pred

            if callback_on_step_end is not None:
                step_tensors = {
                    "latents": latents,
                    "prompt_embeds": prompt_embeds,
                    "pooled_prompt_embeds": pooled_prompt_embeds,
                }
                callback_kwargs = {
                    k: step_tensors[k] for k in (callback_on_step_end_tensor_inputs or [])
                }
                callback_outputs = callback_on_step_end(self, i, timestep[0], callback_kwargs)
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                pooled_prompt_embeds = callback_outputs.pop(
                    "pooled_prompt_embeds", pooled_prompt_embeds
                )

            if XLA_AVAILABLE:
                xm.mark_step()

        if output_type == "latent":
            image = D_x
        else:
            latents_to_decode = (
                D_x / self.vae.config.scaling_factor
            ) + self.vae.config.shift_factor
            image = self.vae.decode(latents_to_decode, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return SiDPipelineOutput(images=image)
|
|