| from transformers import PretrainedConfig, PreTrainedModel |
| from torch import nn, tensor, concat |
| from diffusers.models.embeddings import get_timestep_embedding |
| import torch |
|
|
class T5DiffusionXLTextEncoderMergerConfig(PretrainedConfig):
    """Configuration for ``T5DiffusionXLTextEncoderMerger``.

    Args:
        num_layers: Number of ``Linear`` layers in the merger's second
            forward block (``num_layers - 1`` hidden layers plus the
            output projection).
        dim_timestep_embeds: Width of the sinusoidal timestep embedding
            fed to the timestep-modulation branch.
        seq_len: Token sequence length of the encoder outputs
            (77 matches the SDXL text encoders' default).
        channels_sdxl: Channel width of the SDXL text-encoder hidden
            states (2048 for SDXL's concatenated CLIP encoders).
        channels_t5: Channel width of the T5 encoder hidden states
            (4096 for T5-XXL).
        channels_pooled: Width of the SDXL pooled text embedding.
    """

    # Unique key so this config works with serialization and the
    # AutoConfig/AutoModel registration machinery.
    model_type = "t5_diffusionxl_text_encoder_merger"

    def __init__(self,
                 num_layers: int = 4,
                 dim_timestep_embeds: int = 16,
                 seq_len: int = 77,
                 channels_sdxl: int = 2048,
                 channels_t5: int = 4096,
                 channels_pooled: int = 1280,
                 **kwargs):
        super().__init__(**kwargs)
        self.num_layers = num_layers
        self.dim_timestep_embeds = dim_timestep_embeds
        self.seq_len = seq_len
        self.channels_sdxl = channels_sdxl
        self.channels_t5 = channels_t5
        self.channels_pooled = channels_pooled
|
|
|
|
# NOTE: PreTrainedModel already subclasses nn.Module, so listing nn.Module
# as an extra base was redundant.
class T5DiffusionXLTextEncoderMerger(PreTrainedModel):
    """Merges T5 and SDXL text-encoder embeddings into SDXL-shaped embeddings.

    The T5 and SDXL hidden states are concatenated along the channel axis,
    projected, and modulated FiLM-style by two conditioning branches (the
    SDXL pooled embedding and the current diffusion timestep). A residual
    connection back to ``embeds_sdxl`` keeps the module close to identity
    at initialization.
    """

    # Lets PreTrainedModel.from_pretrained() instantiate the right config.
    config_class = T5DiffusionXLTextEncoderMergerConfig

    def __init__(self, config: T5DiffusionXLTextEncoderMergerConfig):
        super().__init__(config)
        # Diffusion timestep used to condition the merge; callers update it
        # via set_timestep() before each denoising step.
        self._last_timestep = 0
        channels_concat = config.channels_sdxl + config.channels_t5
        # First projection + LayerNorm over (seq, channel) without affine
        # params — scale/shift come from the modulation branches instead.
        self.block_forward1 = nn.Sequential(
            nn.Linear(channels_concat, channels_concat),
            nn.LayerNorm([config.seq_len, channels_concat],
                         elementwise_affine=False))

        # MLP mapping the concatenated width back down to the SDXL channel
        # count; the final Tanh bounds the output so the residual add in
        # forward() cannot blow up the embedding scale.
        layers = []
        for _ in range(config.num_layers - 1):
            layers.append(nn.Linear(channels_concat, channels_concat))
            layers.append(nn.SiLU())
        layers.append(nn.Linear(channels_concat, config.channels_sdxl))
        layers.append(nn.Tanh())
        self.block_forward2 = nn.Sequential(*layers)

        # Each modulation branch emits, per sequence position:
        # gamma (channels_concat) + beta (channels_concat) + zeta (channels_sdxl).
        dim_modulation = config.seq_len * (channels_concat * 2 +
                                           config.channels_sdxl)
        self.block_modulate_by_pooled = nn.Sequential(
            nn.Linear(config.channels_pooled, 512, bias=False), nn.SiLU(),
            nn.Linear(512, dim_modulation, bias=False))

        self.block_modulate_by_timestep = nn.Sequential(
            nn.Linear(config.dim_timestep_embeds, 512, bias=False), nn.SiLU(),
            nn.Linear(512, dim_modulation, bias=False))

    def _init_weights(self, module):
        """Weight-init hook invoked by ``PreTrainedModel.init_weights()``.

        Operates on ``.data``: calling ``normal_``/``zero_`` directly on a
        leaf parameter that requires grad raises a RuntimeError.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(0, 0.1)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, embeds_t5, embeds_sdxl, pooled_embeds_sdxl):
        """Fuse the two encoders' hidden states.

        Args:
            embeds_t5: T5 hidden states, assumed shape
                ``(batch, seq_len, channels_t5)`` — TODO confirm against caller.
            embeds_sdxl: SDXL hidden states, assumed shape
                ``(batch, seq_len, channels_sdxl)``.
            pooled_embeds_sdxl: SDXL pooled embedding,
                ``(batch, channels_pooled)``.

        Returns:
            Dict with key ``"output"`` holding a tensor shaped like
            ``embeds_sdxl``.
        """
        batch_size = embeds_sdxl.size(0)
        # Fixed: the original assert compared embeds_sdxl's batch dim with
        # itself and never checked embeds_t5.
        assert (embeds_t5.size(0) == batch_size ==
                pooled_embeds_sdxl.size(0))
        channels_sdxl = self.config.channels_sdxl
        channels_concat = self.config.channels_t5 + channels_sdxl
        seq_len = self.config.seq_len
        # Build the timestep tensor on the inputs' device (the original
        # always allocated on CPU, breaking GPU execution) and match their
        # dtype so fp16/bf16 runs don't hit a dtype-mismatch matmul.
        timestep_embeds = get_timestep_embedding(
            tensor([self._last_timestep], device=embeds_sdxl.device),
            embedding_dim=self.config.dim_timestep_embeds).to(
                dtype=embeds_sdxl.dtype).repeat(batch_size, 1)
        modulation = self.block_modulate_by_timestep(
            timestep_embeds) + self.block_modulate_by_pooled(pooled_embeds_sdxl)
        gamma, beta, zeta = [
            m.view(batch_size, seq_len, -1) for m in modulation.split([
                seq_len * channels_concat, seq_len * channels_concat, seq_len *
                channels_sdxl
            ],
                                                                      dim=1)
        ]
        # FiLM-style conditioning: scale/shift block 1's output, scale
        # block 2's output, then add the SDXL residual.
        output = (gamma + 1) * self.block_forward1(
            concat((embeds_t5, embeds_sdxl), dim=2)) + beta
        output = (zeta + 1) * self.block_forward2(output)
        output = output + embeds_sdxl
        return {"output": output}

    def set_timestep(self, timestep: int):
        """Record the diffusion timestep used by the next ``forward()``."""
        self._last_timestep = timestep
|
|