import itertools
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import DPMSolverMultistepScheduler
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import Transformer2DModel, ModelMixin, ConfigMixin, SchedulerMixin
from diffusers import UNet2DConditionModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.resnet import ResnetBlock2D, Downsample2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModelOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.
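
    Example (a minimal doctest-style sketch with random stand-in tensors; with
    `guidance_rescale=1.0` the output recovers the per-sample std of `noise_pred_text`):

        >>> text = torch.randn(2, 4, 8, 8)
        >>> cfg = 3.0 * text
        >>> out = rescale_noise_cfg(cfg, text, guidance_rescale=1.0)
        >>> torch.allclose(out.std(dim=(1, 2, 3)), text.std(dim=(1, 2, 3)), atol=1e-5)
        True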
| """ |
| std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) |
| std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) |
| |
| noise_pred_rescaled = noise_cfg * (std_text / std_cfg) |
| |
| noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg |
| return noise_cfg |


def custom_sort_order(obj):
    """
    Key function for sorting the order of execution in forward methods: when blocks are
    not mixed in the forward pass, all `ResnetBlock2D` modules are ordered before the
    transformer (attention) modules.
    """
    return {ResnetBlock2D: 0, Transformer2DModel: 1, FlexibleTransformer2DModel: 1}.get(obj.__class__)


def squeeze_to_len_n_starting_from_index_i(n, i, timestep_spacing):
    """
    Squeeze a timestep array into a shorter one, keeping its first `i` entries unchanged.

    :param n: the length of the squeezed array
    :param i: the index squeezing starts from (must satisfy i < n)
    :param timestep_spacing: the timestep array to squeeze
    :return: the squeezed timestep array
    Example:
        timesteps = np.array([967, 907, 846, 786, 725, 665, 604, 544, 484, 423, 363, 302, 242, 181, 121, 60])  (len=16)
        n = 10, i = 6
        Expected:
        [967, 907, 846, 786, 725, 665, 4k, 3k, 2k, k], and if we define 665 = 5k => k = 133
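
    As a doctest (verifying the worked example above):

        >>> ts = np.array([967, 907, 846, 786, 725, 665, 604, 544, 484, 423, 363, 302, 242, 181, 121, 60])
        >>> squeeze_to_len_n_starting_from_index_i(10, 6, ts)
        array([967, 907, 846, 786, 725, 665, 532, 399, 266, 133])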
| """ |
| assert i < n |
| squeezed = np.flip(np.arange(n)) + 1 |
| squeezed[:i] = timestep_spacing[:i] |
| k = squeezed[i - 1] // (n - i + 1) |
| squeezed[i:] *= k |
|
|
| return squeezed |


# Keys are "n,i" squeeze modes: squeeze the schedule to length n, starting from index i
PREDEFINED_TIMESTEP_SQUEEZERS = {
    "10,6": partial(squeeze_to_len_n_starting_from_index_i, 10, 6),
    "11,7": partial(squeeze_to_len_n_starting_from_index_i, 11, 7),
}


FlexibleUnetConfigurations = {
    # Common
    "sample_size": 64,
    "temb_dim": 320 * 4,
    "resnet_eps": 1e-5,
    "resnet_act_fn": "silu",
    "num_attention_heads": 8,
    "cross_attention_dim": 768,

    # Whether resnets and attentions are interleaved in the forward pass
    "mix_block_in_forward": True,

    # Down blocks
    "down_blocks_in_channels": [320, 320, 640],
    "down_blocks_out_channels": [320, 640, 1280],
    "down_blocks_num_attentions": [0, 1, 3],
    "down_blocks_num_resnets": [2, 2, 1],
    "add_downsample": [True, True, False],

    # Mid block
    "add_upsample_mid_block": None,
    "mid_num_resnets": 0,
    "mid_num_attentions": 0,

    # Up blocks
    "prev_output_channels": [1280, 1280, 640],
    "up_blocks_num_attentions": [5, 3, 0],
    "up_blocks_num_resnets": [2, 3, 3],
    "add_upsample": [True, True, False],
}


class SqueezedDPMSolverMultistepScheduler(DPMSolverMultistepScheduler, SchedulerMixin):
    """
    A copy of Diffusers' `DPMSolverMultistepScheduler`, with minor differences:
    * Defaults are modified to accommodate DeciDiffusion
    * It supports a squeezer that reduces the number of inference steps
    //!\\ IMPORTANT: the actual number of inference steps is deduced by the squeezer, not by the pipeline!
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "squaredcos_cap_v2",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "v_prediction",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "heun",
        lower_order_final: bool = True,
        use_karras_sigmas: Optional[bool] = False,
        lambda_min_clipped: float = -7.5,
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 1,
        squeeze_mode: Optional[str] = None,
    ):
        self._squeezer = PREDEFINED_TIMESTEP_SQUEEZERS.get(squeeze_mode)

        if use_karras_sigmas:
            raise NotImplementedError("Squeezing isn't tested with `use_karras_sigmas`. Please provide `use_karras_sigmas=False`")

        super().__init__(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            solver_order=solver_order,
            prediction_type=prediction_type,
            thresholding=thresholding,
            dynamic_thresholding_ratio=dynamic_thresholding_ratio,
            sample_max_value=sample_max_value,
            algorithm_type=algorithm_type,
            solver_type=solver_type,
            lower_order_final=lower_order_final,
            use_karras_sigmas=False,
            lambda_min_clipped=lambda_min_clipped,
            variance_type=variance_type,
            timestep_spacing=timestep_spacing,
            steps_offset=steps_offset,
        )

    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        super().set_timesteps(num_inference_steps=num_inference_steps, device=device)
        if self._squeezer is not None:
            # Squeeze the schedule, then recompute sigmas to match the squeezed timesteps
            timesteps = self._squeezer(self.timesteps.cpu())
            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
            self.sigmas = torch.from_numpy(sigmas)
            self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
            self.num_inference_steps = len(timesteps)
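
# A minimal usage sketch (assumes the "10,6" squeeze mode): asking for 16 steps
# yields a squeezed 10-step schedule, since the squeezer, not the pipeline,
# determines the final step count.
#
#   >>> sched = SqueezedDPMSolverMultistepScheduler(squeeze_mode="10,6")
#   >>> sched.set_timesteps(num_inference_steps=16)
#   >>> len(sched.timesteps)
#   10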


class FlexibleIdentityBlock(nn.Module):
    """A no-op block that returns `hidden_states` unchanged (used when the mid block is empty)."""

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        return hidden_states


class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
    configurations = FlexibleUnetConfigurations

    @register_to_config
    def __init__(self):
        super().__init__(
            sample_size=self.configurations.get("sample_size", FlexibleUnetConfigurations["sample_size"]),
            cross_attention_dim=self.configurations.get("cross_attention_dim", FlexibleUnetConfigurations["cross_attention_dim"]),
        )

        num_attention_heads = self.configurations.get("num_attention_heads")
        cross_attention_dim = self.configurations.get("cross_attention_dim")
        mix_block_in_forward = self.configurations.get("mix_block_in_forward")
        resnet_act_fn = self.configurations.get("resnet_act_fn")
        resnet_eps = self.configurations.get("resnet_eps")
        temb_dim = self.configurations.get("temb_dim")

        # Down blocks
        down_blocks_num_attentions = self.configurations.get("down_blocks_num_attentions")
        down_blocks_out_channels = self.configurations.get("down_blocks_out_channels")
        down_blocks_in_channels = self.configurations.get("down_blocks_in_channels")
        down_blocks_num_resnets = self.configurations.get("down_blocks_num_resnets")
        add_downsample = self.configurations.get("add_downsample")

        self.down_blocks = nn.ModuleList()
        for i, (in_c, out_c, n_res, n_att, add_down) in enumerate(
            zip(down_blocks_in_channels, down_blocks_out_channels, down_blocks_num_resnets, down_blocks_num_attentions, add_downsample)
        ):
            last_block = i == len(down_blocks_in_channels) - 1
            self.down_blocks.append(
                FlexibleCrossAttnDownBlock2D(
                    in_channels=in_c,
                    out_channels=out_c,
                    temb_channels=temb_dim,
                    num_resnets=n_res,
                    num_attentions=n_att,
                    resnet_eps=resnet_eps,
                    resnet_act_fn=resnet_act_fn,
                    num_attention_heads=num_attention_heads,
                    cross_attention_dim=cross_attention_dim,
                    add_downsample=add_down,
                    last_block=last_block,
                    mix_block_in_forward=mix_block_in_forward,
                )
            )

        # Mid block (replaced by an identity block when it has no resnets and no attentions)
        mid_block_add_upsample = self.configurations.get("add_upsample_mid_block")
        mid_num_attentions = self.configurations.get("mid_num_attentions")
        mid_num_resnets = self.configurations.get("mid_num_resnets")

        if mid_num_resnets == mid_num_attentions == 0:
            self.mid_block = FlexibleIdentityBlock()
        else:
            self.mid_block = FlexibleUNetMidBlock2DCrossAttn(
                in_channels=down_blocks_out_channels[-1],
                temb_channels=temb_dim,
                resnet_act_fn=resnet_act_fn,
                resnet_eps=resnet_eps,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads,
                num_resnets=mid_num_resnets,
                num_attentions=mid_num_attentions,
                mix_block_in_forward=mix_block_in_forward,
                add_upsample=mid_block_add_upsample,
            )

        # Up blocks
        up_blocks_num_attentions = self.configurations.get("up_blocks_num_attentions")
        up_blocks_num_resnets = self.configurations.get("up_blocks_num_resnets")
        prev_output_channels = self.configurations.get("prev_output_channels")
        up_upsample = self.configurations.get("add_upsample")

        self.up_blocks = nn.ModuleList()
        for in_c, out_c, prev_out, n_res, n_att, add_up in zip(
            reversed(down_blocks_in_channels),
            reversed(down_blocks_out_channels),
            prev_output_channels,
            up_blocks_num_resnets,
            up_blocks_num_attentions,
            up_upsample,
        ):
            self.up_blocks.append(
                FlexibleCrossAttnUpBlock2D(
                    in_channels=in_c,
                    out_channels=out_c,
                    prev_output_channel=prev_out,
                    temb_channels=temb_dim,
                    num_resnets=n_res,
                    num_attentions=n_att,
                    resnet_eps=resnet_eps,
                    resnet_act_fn=resnet_act_fn,
                    num_attention_heads=num_attention_heads,
                    cross_attention_dim=cross_attention_dim,
                    add_upsample=add_up,
                    mix_block_in_forward=mix_block_in_forward,
                )
            )
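
# A minimal smoke-test sketch (hypothetical shapes; SD-style latents are 4-channel,
# and with sample_size=64 the UNet should map (B, 4, 64, 64) latents to noise
# predictions of the same shape):
#
#   >>> unet = FlexibleUNet2DConditionModel()
#   >>> sample = torch.randn(1, 4, 64, 64)
#   >>> text_emb = torch.randn(1, 77, 768)  # CLIP text encoder hidden states
#   >>> unet(sample, timestep=1, encoder_hidden_states=text_emb).sample.shape
#   torch.Size([1, 4, 64, 64])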


class FlexibleCrossAttnDownBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        add_downsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        last_block: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()

        self.last_block = last_block
        self.mix_block_in_forward = mix_block_in_forward
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        modules = []

        # Interleave resnets and attentions; `zip_longest` pads the shorter list with False
        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
            in_channels = in_channels if i == 0 else out_channels
            if add_resnet:
                modules.append(
                    ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList([Downsample2D(out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op")])
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        output_states = ()

        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)
            elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            else:
                raise ValueError(f"Got an unexpected module in modules list! {type(module)}")
            # Collect a skip connection after every resnet
            if isinstance(module, ResnetBlock2D):
                output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        if not self.last_block:
            output_states = output_states + (hidden_states,)

        return hidden_states, output_states
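
# A standalone sketch of the skip-connection contract (hypothetical sizes): a block
# with two resnets that also downsamples returns one residual per resnet, plus the
# downsampled output when it is not the last block, i.e. three tensors here.
#
#   >>> block = FlexibleCrossAttnDownBlock2D(in_channels=320, out_channels=320, temb_channels=1280, num_resnets=2, num_attentions=0)
#   >>> h, res = block(torch.randn(1, 320, 64, 64), temb=torch.randn(1, 1280))
#   >>> len(res)
#   3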


class FlexibleCrossAttnUpBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()
        modules = []

        # Bookkeeping only: one entry is appended per resnet added below
        self.resnets = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
            res_skip_channels = in_channels if (i == len(add_resnets) - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            if add_resnet:
                self.resnets += [True]
                modules.append(
                    ResnetBlock2D(
                        in_channels=resnet_in_channels + res_skip_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                # Pop the next skip connection and concatenate it along the channel dim
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = module(hidden_states, temb)
            elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


class FlexibleUNetMidBlock2DCrossAttn(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
        add_upsample: bool = True,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # The mid block always starts with one resnet
        modules = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                    )
                )

            if add_resnet:
                modules.append(
                    ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )

        if not mix_block_in_forward:
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        # Identity unless an upsample is requested
        self.upsamplers = nn.ModuleList([nn.Identity()])
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # The first resnet runs up front; iterate over the remaining modules so it is not applied twice
        hidden_states = self.modules_list[0](hidden_states, temb)

        for module in self.modules_list[1:]:
            if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            elif isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)

        return hidden_states


class FlexibleTransformer2DModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        only_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # Input layers
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.use_linear_projection = use_linear_projection
        if self.use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        # Output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ):
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        # Project the input into the transformer's inner dimension
        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # Transformer blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # Project back to image space and add the residual
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
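
# A shape sketch (hypothetical sizes): the model flattens (B, C, H, W) features into
# (B, H*W, C) tokens for attention and restores the feature map afterwards, so input
# and output shapes match. With `return_dict=False` (the default) a tuple is returned.
#
#   >>> m = FlexibleTransformer2DModel(num_attention_heads=8, attention_head_dim=40, in_channels=320, cross_attention_dim=768)
#   >>> m(torch.randn(1, 320, 16, 16), encoder_hidden_states=torch.randn(1, 77, 768))[0].shape
#   torch.Size([1, 320, 16, 16])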


class DeciDiffusionPipeline(StableDiffusionPipeline):
    deci_default_squeeze_mode = "10,6"
    deci_default_number_of_iterations = 16
    deci_default_guidance_rescale = 0.8

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        # The given UNet is discarded in favor of DeciDiffusion's flexible UNet
        del unet
        unet = FlexibleUNet2DConditionModel()

        # The given scheduler is discarded in favor of the squeezed DPM-Solver scheduler
        del scheduler
        scheduler = SqueezedDPMSolverMultistepScheduler(squeeze_mode=self.deci_default_squeeze_mode)

        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 16,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.8,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 16):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. Note that a squeezed scheduler may reduce the actual number of steps.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that is called every `callback_steps` steps during inference with the following
                arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.8):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # Classifier-free guidance is enabled when `guidance_scale > 1`
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )
        # For classifier-free guidance, concatenate the negative and positive embeddings
        # into a single batch to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps (the squeezer may shrink the schedule, so trust the scheduler's timesteps)
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                # Expand the latents if we are doing classifier-free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # Perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # Compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # Call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type != "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
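

# A minimal usage sketch (assumptions: the checkpoint id "Deci/DeciDiffusion-v1-0"
# and the CUDA device are illustrative, substitute your own; `from_pretrained` and
# `to` are the standard diffusers pipeline APIs).
if __name__ == "__main__":
    pipeline = DeciDiffusionPipeline.from_pretrained(
        "Deci/DeciDiffusion-v1-0",  # hypothetical checkpoint id
        torch_dtype=torch.float16,
    ).to("cuda")
    # The 16 requested steps are squeezed to 10 by the default "10,6" squeeze mode
    image = pipeline(
        prompt="A photo of an astronaut riding a horse on Mars",
        num_inference_steps=DeciDiffusionPipeline.deci_default_number_of_iterations,
        guidance_rescale=DeciDiffusionPipeline.deci_default_guidance_rescale,
    ).images[0]
    image.save("decidiffusion_sample.png")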