# t5diffusionxl / text_encoder_merger.py
# Author: Haowei Chen
# Sketch of text encoder merger
# Commit: 1fdf2ca
from transformers import PretrainedConfig, PreTrainedModel
from torch import nn, tensor, concat
from diffusers.models.embeddings import get_timestep_embedding
import torch
class T5DiffusionXLTextEncoderMergerConfig(PretrainedConfig):
    """Configuration for the T5 / SDXL text-encoder merger module.

    Stores the layer count and the channel widths of the embeddings the
    merger consumes and produces; all remaining keyword arguments are
    forwarded to ``PretrainedConfig``.
    """

    def __init__(
        self,
        num_layers: int = 4,
        dim_timestep_embeds: int = 16,
        seq_len: int = 77,
        channels_sdxl: int = 2048,
        channels_t5: int = 4096,
        channels_pooled: int = 1280,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Depth of the merger's second feed-forward block.
        self.num_layers = num_layers
        # Width of the sinusoidal timestep embedding fed to the modulation net.
        self.dim_timestep_embeds = dim_timestep_embeds
        # Fixed token-sequence length the merger is built for.
        self.seq_len = seq_len
        # Channel widths of the SDXL / T5 per-token and pooled embeddings.
        self.channels_sdxl = channels_sdxl
        self.channels_t5 = channels_t5
        self.channels_pooled = channels_pooled
class T5DiffusionXLTextEncoderMerger(PreTrainedModel):
    """Merge T5 and SDXL text-encoder sequence embeddings into SDXL-shaped ones.

    The T5 and SDXL per-token embeddings are concatenated along the channel
    axis and passed through two feed-forward blocks.  Both blocks are
    FiLM-style modulated (scales ``gamma``/``zeta``, shift ``beta``) by a
    projection of the pooled SDXL embedding plus a projection of the current
    diffusion-timestep embedding, and the result is added residually onto the
    original SDXL embeddings.
    """

    # Lets PreTrainedModel.from_pretrained / save_pretrained round-trip the
    # matching config class.
    config_class = T5DiffusionXLTextEncoderMergerConfig

    def __init__(self, config: T5DiffusionXLTextEncoderMergerConfig):
        # NOTE: PreTrainedModel already derives from nn.Module, so the
        # original's extra nn.Module base was redundant and has been dropped.
        super().__init__(config)
        # Timestep consumed by forward(); update per denoising step via
        # set_timestep().  Defaults to 0 so forward() works without a call.
        self._last_timestep = 0
        channels_concat = config.channels_sdxl + config.channels_t5
        # First block: one linear plus a parameter-free LayerNorm over the
        # whole (seq_len, channels) map — this pins inputs to config.seq_len.
        self.block_forward1 = nn.Sequential(
            nn.Linear(channels_concat, channels_concat),
            nn.LayerNorm([config.seq_len, channels_concat],
                         elementwise_affine=False))
        # Second block: (num_layers - 1) hidden layers, then a projection down
        # to the SDXL channel count, squashed with Tanh so the residual update
        # stays bounded.
        layers = []
        for _ in range(config.num_layers - 1):
            layers.append(nn.Linear(channels_concat, channels_concat))
            layers.append(nn.SiLU())
        layers.append(nn.Linear(channels_concat, config.channels_sdxl))
        layers.append(nn.Tanh())
        self.block_forward2 = nn.Sequential(*layers)
        # Each modulation net emits, per token: gamma and beta for block 1's
        # output (channels_concat wide each) and zeta for block 2's output
        # (channels_sdxl wide) — hence the flat output width below.
        modulation_width = config.seq_len * (
            channels_concat * 2 + config.channels_sdxl)
        self.block_modulate_by_pooled = nn.Sequential(
            nn.Linear(config.channels_pooled, 512, bias=False), nn.SiLU(),
            nn.Linear(512, modulation_width, bias=False))
        self.block_modulate_by_timestep = nn.Sequential(
            nn.Linear(config.dim_timestep_embeds, 512, bias=False), nn.SiLU(),
            nn.Linear(512, modulation_width, bias=False))

    def _init_weights(self, module):
        """Weight-init hook invoked by PreTrainedModel's init machinery."""
        if isinstance(module, nn.Linear):
            # Bug fix: in-place normal_()/zero_() on a leaf parameter that
            # requires grad raises a RuntimeError; operate on .data as
            # transformers' own _init_weights implementations do.
            module.weight.data.normal_(0, 0.1)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, embeds_t5, embeds_sdxl, pooled_embeds_sdxl):
        """Fuse the two encoders' outputs into SDXL-shaped embeddings.

        Args:
            embeds_t5: T5 per-token embeddings; assumed
                (batch, seq_len, channels_t5) — TODO confirm with caller.
            embeds_sdxl: SDXL per-token embeddings,
                (batch, seq_len, channels_sdxl).
            pooled_embeds_sdxl: pooled SDXL embedding,
                (batch, channels_pooled).

        Returns:
            dict with key "output": a tensor shaped like ``embeds_sdxl``.
        """
        batch_size = embeds_sdxl.size(0)
        # Bug fix: the original asserted embeds_sdxl.size(0) against itself,
        # so the T5 batch dimension was never actually validated.
        assert batch_size == embeds_t5.size(0) == pooled_embeds_sdxl.size(0)
        channels_sdxl = self.config.channels_sdxl
        channels_concat = self.config.channels_t5 + channels_sdxl
        seq_len = self.config.seq_len
        # Bug fix: create the timestep tensor on the inputs' device and cast
        # to their dtype — the original always built it on CPU in the default
        # dtype, which broke GPU and mixed-precision runs.
        timestep_embeds = get_timestep_embedding(
            tensor([self._last_timestep], device=pooled_embeds_sdxl.device),
            embedding_dim=self.config.dim_timestep_embeds).to(
                dtype=pooled_embeds_sdxl.dtype).repeat(batch_size, 1)
        modulation = self.block_modulate_by_timestep(
            timestep_embeds) + self.block_modulate_by_pooled(pooled_embeds_sdxl)
        # Split the flat modulation vector into two scales and one shift and
        # restore the (batch, seq_len, channels) layout of each.
        gamma, beta, zeta = [
            m.view(batch_size, seq_len, -1) for m in modulation.split([
                seq_len * channels_concat, seq_len * channels_concat,
                seq_len * channels_sdxl
            ], dim=1)
        ]
        # FiLM-style modulation: (scale + 1) keeps the map near identity when
        # the modulation nets output values near zero.
        output = (gamma + 1) * self.block_forward1(
            concat((embeds_t5, embeds_sdxl), dim=2)) + beta
        output = (zeta + 1) * self.block_forward2(output)
        output += embeds_sdxl  # residual connection onto the SDXL branch
        return {"output": output}

    def set_timestep(self, timestep: int):
        """Record the diffusion timestep used by subsequent forward() calls."""
        self._last_timestep = timestep