"""Learned merger that fuses T5 token embeddings into SDXL text-encoder space.

The merged sequence is produced by concatenating the T5 and SDXL token
embeddings, passing them through two MLP stages, and modulating each stage
with FiLM-style (scale/shift) signals derived from the SDXL pooled embedding
and a sinusoidal embedding of the current diffusion timestep. A residual
connection adds the original SDXL embeddings back onto the output.
"""

import torch
from torch import concat, nn, tensor

from diffusers.models.embeddings import get_timestep_embedding
from transformers import PretrainedConfig, PreTrainedModel


class T5DiffusionXLTextEncoderMergerConfig(PretrainedConfig):
    """Configuration for :class:`T5DiffusionXLTextEncoderMerger`.

    Args:
        num_layers: Total number of linear layers across both MLP stages
            (stage 1 has one layer; stage 2 has ``num_layers - 1``).
        dim_timestep_embeds: Dimension of the sinusoidal timestep embedding.
        seq_len: Token sequence length (77 for SDXL's CLIP tokenizers).
        channels_sdxl: Channel width of the SDXL token embeddings.
        channels_t5: Channel width of the T5 token embeddings.
        channels_pooled: Channel width of the SDXL pooled embedding.
    """

    def __init__(self,
                 num_layers: int = 4,
                 dim_timestep_embeds: int = 16,
                 seq_len: int = 77,
                 channels_sdxl: int = 2048,
                 channels_t5: int = 4096,
                 channels_pooled: int = 1280,
                 **kwargs):
        super().__init__(**kwargs)
        self.num_layers = num_layers
        self.dim_timestep_embeds = dim_timestep_embeds
        self.seq_len = seq_len
        self.channels_sdxl = channels_sdxl
        self.channels_t5 = channels_t5
        self.channels_pooled = channels_pooled


# NOTE: `PreTrainedModel` already subclasses `nn.Module`; the original extra
# `nn.Module` base was redundant and has been dropped (MRO is unchanged).
class T5DiffusionXLTextEncoderMerger(PreTrainedModel):
    """Merges T5 and SDXL text embeddings into an SDXL-shaped sequence."""

    # Required for `from_pretrained` to reconstruct the config automatically.
    config_class = T5DiffusionXLTextEncoderMergerConfig

    def __init__(self, config: T5DiffusionXLTextEncoderMergerConfig):
        super().__init__(config)
        # Diffusion timestep used by `forward`; updated via `set_timestep`.
        self._last_timestep = 0

        channels_concat = config.channels_sdxl + config.channels_t5

        # Stage 1: single linear over the concatenated channels, followed by
        # a non-affine LayerNorm (the affine part is supplied by modulation).
        self.block_forward1 = nn.Sequential(
            nn.Linear(channels_concat, channels_concat),
            nn.LayerNorm([config.seq_len, channels_concat],
                         elementwise_affine=False))

        # Stage 2: (num_layers - 1) linears; the last one projects down to the
        # SDXL channel width, squashed with Tanh before the residual add.
        layers = []
        for _ in range(config.num_layers - 1):
            layers.append(nn.Linear(channels_concat, channels_concat))
            layers.append(nn.SiLU())
        layers.append(nn.Linear(channels_concat, config.channels_sdxl))
        layers.append(nn.Tanh())
        self.block_forward2 = nn.Sequential(*layers)

        # Both modulation heads emit, per token: gamma (channels_concat),
        # beta (channels_concat) and zeta (channels_sdxl) — see `forward`.
        dim_modulation = config.seq_len * (channels_concat * 2
                                           + config.channels_sdxl)
        self.block_modulate_by_pooled = nn.Sequential(
            nn.Linear(config.channels_pooled, 512, bias=False),
            nn.SiLU(),
            nn.Linear(512, dim_modulation, bias=False))
        self.block_modulate_by_timestep = nn.Sequential(
            nn.Linear(config.dim_timestep_embeds, 512, bias=False),
            nn.SiLU(),
            nn.Linear(512, dim_modulation, bias=False))

    def _init_weights(self, module):
        """Initialize linear layers with N(0, 0.1) weights and zero biases."""
        if isinstance(module, nn.Linear):
            # BUG FIX: in-place `normal_`/`zero_` on leaf parameters that
            # require grad raises a RuntimeError; wrap in `torch.no_grad()`.
            with torch.no_grad():
                module.weight.normal_(0, 0.1)
                if module.bias is not None:
                    module.bias.zero_()

    def forward(self, embeds_t5, embeds_sdxl, pooled_embeds_sdxl):
        """Fuse T5 and SDXL embeddings.

        Args:
            embeds_t5: T5 token embeddings, (batch, seq_len, channels_t5).
            embeds_sdxl: SDXL token embeddings,
                (batch, seq_len, channels_sdxl).
            pooled_embeds_sdxl: SDXL pooled embedding,
                (batch, channels_pooled).

        Returns:
            Dict with key ``"output"``: merged embeddings of shape
            (batch, seq_len, channels_sdxl).
        """
        batch_size = embeds_sdxl.size(0)
        # BUG FIX: the original asserted `embeds_sdxl.size(0)` against itself,
        # so a T5 batch mismatch was never caught.
        assert batch_size == embeds_t5.size(0) == pooled_embeds_sdxl.size(0)

        channels_sdxl = self.config.channels_sdxl
        channels_concat = self.config.channels_t5 + channels_sdxl
        seq_len = self.config.seq_len

        # BUG FIX: `get_timestep_embedding` produces a CPU float32 tensor;
        # move/cast it to the inputs' device and dtype so the modulation
        # matmul works on CUDA / half-precision models.
        timestep_embeds = get_timestep_embedding(
            tensor([self._last_timestep]),
            embedding_dim=self.config.dim_timestep_embeds).to(
                device=pooled_embeds_sdxl.device,
                dtype=pooled_embeds_sdxl.dtype).repeat(batch_size, 1)

        # Combine timestep- and pooled-conditioned modulation signals.
        modulation = self.block_modulate_by_timestep(
            timestep_embeds) + self.block_modulate_by_pooled(
                pooled_embeds_sdxl)
        gamma, beta, zeta = [
            m.view(batch_size, seq_len, -1) for m in modulation.split([
                seq_len * channels_concat, seq_len * channels_concat,
                seq_len * channels_sdxl
            ], dim=1)
        ]

        # FiLM-style modulation: scale (gamma+1) / shift (beta) on stage 1,
        # scale (zeta+1) on stage 2, then residual-add the SDXL embeddings.
        output = (gamma + 1) * self.block_forward1(
            concat((embeds_t5, embeds_sdxl), dim=2)) + beta
        output = (zeta + 1) * self.block_forward2(output)
        output = output + embeds_sdxl
        return {"output": output}

    def set_timestep(self, timestep: int):
        """Record the diffusion timestep used by subsequent `forward` calls."""
        self._last_timestep = timestep