| """ |
| Definition of Infinity transformer model. |
| """ |
|
|
| import math |
| import random |
| import time |
| from contextlib import nullcontext |
| from functools import partial |
| from typing import List, Optional, Tuple, Union, Dict, Any |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from timm.models import register_model |
| from torch.utils.checkpoint import checkpoint |
| from PIL import Image |
| import numpy as np |
| from torch.nn.attention.flex_attention import flex_attention |
|
|
| import infinity.utils.dist as dist |
|
|
| from infinity.utils.dist import for_visualize |
| from infinity.models.basic import flash_attn_func, flash_fused_op_installed, AdaLNBeforeHead, CrossAttnBlock, SelfAttnBlock, CrossAttention, FastRMSNorm, precompute_rope2d_freqs_grid |
| from infinity.utils import misc |
| from infinity.models.flex_attn import FlexAttn |
| from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates |
| try: |
| from infinity.models.fused_op import fused_ada_layer_norm, fused_ada_rms_norm |
| except: |
| fused_ada_layer_norm, fused_ada_rms_norm = None, None |
| |
| import pdb |
|
|
class MultiInpIdentity(nn.Module):
    """Identity module whose forward accepts, and ignores, any extra inputs.

    It can stand in wherever a module is called with several arguments
    (e.g. ``head_nm(h, cond)``) but no transformation is wanted.
    """

    def forward(self, x, *args, **kwargs):
        del args, kwargs  # extra inputs are intentionally unused
        return x
|
|
|
|
class TextAttentivePool(nn.Module):
    """Pool a packed text sequence into one vector per sample with cross-attention.

    Wraps ``CrossAttention(for_attn_pool=True, ...)`` with key/value width ``Ct5``
    (text-feature channels) and output width ``D``. ``forward`` removes dimension 1
    of the attention output.
    """

    def __init__(self, Ct5: int, D: int):
        super().__init__()
        self.Ct5, self.D = Ct5, D
        # Very wide models (D > 4096) use narrower attention heads.
        self.head_dim = 64 if D > 4096 else 128
        # The head count is derived from the text width, not from D.
        self.num_heads = Ct5 // self.head_dim
        self.ca = CrossAttention(
            for_attn_pool=True,
            embed_dim=self.D,
            kv_dim=Ct5,
            num_heads=self.num_heads,
        )

    def forward(self, ca_kv):
        # No query is passed: the attention-pool variant presumably supplies its own.
        pooled = self.ca(None, ca_kv)
        return pooled.squeeze(1)
|
|
class SharedAdaLin(nn.Linear):
    """Linear layer whose output is split into 6 modulation vectors per sample.

    The linear result is reshaped to ``(-1, 1, 6, out_features // 6)``.
    ``out_features`` is expected to be a multiple of 6.
    """

    def forward(self, cond_BD):
        projected = super().forward(cond_BD)
        chunk_width = self.weight.shape[0] // 6
        return projected.reshape(-1, 1, 6, chunk_width)
|
|
|
|
class MultipleLayers(nn.Module):
    """Group a contiguous run of blocks into one module.

    ``ls[index]`` .. ``ls[index + num_blocks_in_a_chunk - 1]`` are registered (the same
    objects, not copies) in ``self.module``. Each block is called as
    ``block(h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule,
    rope2d_freqs_grid)`` and its output is fed to the next block.
    """

    def __init__(self, ls, num_blocks_in_a_chunk, index):
        super().__init__()
        self.module = nn.ModuleList(
            ls[pos] for pos in range(index, index + num_blocks_in_a_chunk)
        )

    @staticmethod
    def _run_layers(layers, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn,
                    scale_schedule, checkpointing_full_block, rope2d_freqs_grid):
        """Apply ``layers`` in order, optionally with non-reentrant activation checkpointing."""
        hidden = x
        for layer in layers:
            call_args = (hidden, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn,
                         scale_schedule, rope2d_freqs_grid)
            if not checkpointing_full_block:
                hidden = layer(*call_args)
            else:
                hidden = torch.utils.checkpoint.checkpoint(layer, *call_args, use_reentrant=False)
        return hidden

    def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, checkpointing_full_block=False, rope2d_freqs_grid=None):
        return self._run_layers(self.module, x, cond_BD, ca_kv, attn_bias_or_two_vector,
                                attn_fn, scale_schedule, checkpointing_full_block, rope2d_freqs_grid)

    def forward_fsdp(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, checkpointing_full_block=False, rope2d_freqs_grid=None):
        # Same as forward, but ``self.module`` is expected to wrap the block list in a
        # ``.module`` attribute (presumably an FSDP wrapper).
        return self._run_layers(self.module.module, x, cond_BD, ca_kv, attn_bias_or_two_vector,
                                attn_fn, scale_schedule, checkpointing_full_block, rope2d_freqs_grid)
|
|
class STGumbelArgmax(torch.autograd.Function):
    """Straight-through Gumbel-argmax.

    Forward draws ONE Gumbel-softmax sample ``y_soft = softmax((logits + g) / tau)``
    along the last dimension and returns its exact one-hot argmax. Backward passes the
    incoming gradient through that same relaxed sample: it returns the vector-Jacobian
    product of ``softmax((logits + g) / tau)``, using the noise ``g`` drawn in forward.
    The previous version re-drew independent noise in backward and multiplied
    elementwise, which is not the gradient of the returned sample.

    Usage: ``STGumbelArgmax.apply(logits, tau)``.
    """

    @staticmethod
    def forward(ctx, logits, tau=1.0):
        # Gumbel(0, 1) noise, drawn the same way as torch.nn.functional.gumbel_softmax.
        gumbels = -torch.empty_like(logits).exponential_().log()
        y_soft = ((logits + gumbels) / tau).softmax(dim=-1)
        index = y_soft.argmax(dim=-1, keepdim=True)
        y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
        # Keep the soft sample of THIS draw so backward matches forward.
        ctx.save_for_backward(y_soft)
        ctx.tau = tau
        return y_hard

    @staticmethod
    def backward(ctx, grad_output):
        y_soft, = ctx.saved_tensors
        # VJP of softmax(z / tau) w.r.t. z:  y * (g - <g, y>) / tau
        dot = (grad_output * y_soft).sum(dim=-1, keepdim=True)
        grad_input = y_soft * (grad_output - dot) / ctx.tau
        return grad_input, None
| |
def GumbelArgmax(logits, tau):
    """Sample a one-hot Gumbel-argmax with a straight-through gradient.

    :param logits: unnormalised scores; sampling is over the last dimension.
    :param tau: softmax temperature applied to the perturbed logits.
    :return: a tensor shaped like ``logits`` whose forward value is the one-hot argmax
        of ``softmax((logits + g) / tau)`` (g ~ Gumbel(0, 1)), and whose gradient is
        that of the soft sample (``y_hard - y_soft.detach() + y_soft``).
    """
    # Draw the uniforms directly on the logits' device (default dtype, as before)
    # instead of allocating on the CPU and copying to the device on every call.
    uniform = torch.rand(logits.shape, device=logits.device)
    eps = 1e-20  # keeps both logs away from log(0)
    gumbel_noise = -torch.log(-torch.log(uniform + eps) + eps)

    y_soft = F.softmax((logits + gumbel_noise) / tau, dim=-1)
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
    # Straight-through: forward value is y_hard, gradient flows through y_soft.
    return y_hard - y_soft.detach() + y_soft
|
|
| class Infinity(nn.Module): |
def __init__(
    self, vae_local,
    text_channels=0, text_maxlen=0,
    selecting_idx=None,
    embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4.,
    drop_rate=0., drop_path_rate=0.,
    norm_eps=1e-6, rms_norm=False,
    shared_aln=False, head_aln=True,
    cond_drop_rate=0.1,
    rand_uncond=False,
    cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False,
    raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
    head_depth=1,
    top_p=0.0, top_k=0.0,
    customized_flash_attn=False, fused_mlp=False, fused_norm=False,
    block_chunks=1,
    checkpointing=None,
    pad_to_multiplier=0,
    use_flex_attn=False,
    batch_size=2,
    add_lvl_embeding_only_first_block=1,
    use_bit_label=1,
    rope2d_each_sa_layer=0,
    rope2d_normalized_by_hw=0,
    pn=None,
    train_h_div_w_list=None,
    video_frames=1,
    always_training_scales=20,
    apply_spatial_patchify = 0,
    inference_mode=False,
):
    """Build the Infinity transformer.

    Only what this constructor visibly does with the arguments is listed:

    :param vae_local: VAE providing ``embed_dim`` (token width), and ``vocab_size`` or
        ``quantizer.lfq.mask`` depending on ``use_bit_label``.
    :param text_channels: width ``Ct5`` of the text features; 0 selects class-conditional
        mode (``self.t2i`` is False).
    :param use_bit_label: if truthy, the output vocabulary is ``2 * codebook_dim``.
    :param apply_spatial_patchify: if truthy, the token width is ``4 * vae_local.embed_dim``.
    :param block_chunks: number of ``MultipleLayers`` chunks the ``depth`` blocks are
        grouped into. ``None``/``0`` is treated as 1; ``depth`` must be divisible by it.
    :param checkpointing: ``'full-block'``, ``'self-attn'`` or None.
    :param pad_to_multiplier: sequences are padded to a multiple of ``max(1, pad_to_multiplier)``.
    :param use_flex_attn: pre-build FlexAttention objects via ``compile_flex_attn``.
    :param rope2d_each_sa_layer: must be 1 (0 raises ValueError).
    :raises ValueError: if ``rope2d_each_sa_layer`` is 0.
    """
    # Initialise nn.Module bookkeeping before any attribute is assigned.
    super().__init__()

    self.C = embed_dim
    self.inference_mode = inference_mode
    self.apply_spatial_patchify = apply_spatial_patchify
    if self.apply_spatial_patchify:
        self.d_vae = vae_local.embed_dim * 4
    else:
        self.d_vae = vae_local.embed_dim
    self.use_bit_label = use_bit_label
    self.codebook_dim = self.d_vae
    self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size
    self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None
    self.Ct5 = text_channels
    self.depth = depth
    self.num_heads = num_heads
    self.batch_size = batch_size
    self.mlp_ratio = mlp_ratio
    self.cond_drop_rate = cond_drop_rate
    self.norm_eps = norm_eps
    self.prog_si = -1
    self.pn = pn
    self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates
    self.video_frames = video_frames
    self.always_training_scales = always_training_scales

    assert add_lvl_embeding_only_first_block in [0,1]
    self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block
    assert rope2d_each_sa_layer in [0,1]
    self.rope2d_each_sa_layer = rope2d_each_sa_layer
    self.rope2d_normalized_by_hw = rope2d_normalized_by_hw
    # Implicit concatenation (instead of a backslash continuation inside the f-string)
    # keeps source indentation out of the log line.
    print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, '
          f'self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}')
    head_up_method = ''
    word_patch_size = 1 if head_up_method in {'', 'no'} else 2
    if word_patch_size > 1:
        assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}'

    self.checkpointing = checkpointing
    self.pad_to_multiplier = max(1, pad_to_multiplier)

    # The customized FlashAttention kernel is detected by an 'Infinity'-named argument.
    customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames)
    self.customized_flash_attn = customized_flash_attn and customized_kernel_installed
    if customized_flash_attn and not customized_kernel_installed:
        import inspect, warnings
        file_path = inspect.getsourcefile(flash_attn_func)
        line_number = inspect.getsourcelines(flash_attn_func)[1]
        info = (
            f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<\n'
            f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n'
            f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n'
        )
        warnings.warn(info, ImportWarning)
        print(info, flush=True)

    self.raw_scale_schedule = raw_scale_schedule
    self.first_l = 1

    # top_k in (0, 1) is a fraction of the vocabulary; out-of-range values disable filtering.
    self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k))
    if self.top_p < 1e-5: self.top_p = 0
    if self.top_k >= self.V or self.top_k <= 0: self.top_k = 0

    # Every rank must agree on whether the fused flash ops are installed.
    t = torch.zeros(dist.get_world_size(), device=dist.get_device())
    t[dist.get_rank()] = float(flash_fused_op_installed)
    dist.barrier()
    dist.allreduce(t)
    assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}'

    self.rng = torch.Generator(device=dist.get_device())
    self.maybe_record_function = nullcontext
    self.text_maxlen = text_maxlen
    self.t2i = text_channels != 0

    init_std = math.sqrt(1 / self.C / 3)
    self.norm0_cond = nn.Identity()
    if self.t2i:
        self.selecting_idx = None
        self.num_classes = 0
        self.D = self.C

        # Deterministic (seed 0) "unconditional" text embedding used for CFG dropout.
        cfg_uncond = torch.empty(self.text_maxlen, self.Ct5)
        rng = torch.Generator(device='cpu')
        rng.manual_seed(0)
        torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng)
        cfg_uncond /= self.Ct5 ** 0.5
        if rand_uncond:
            self.register_buffer('cfg_uncond', cfg_uncond)
        else:
            self.cfg_uncond = nn.Parameter(cfg_uncond)

        self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps)
        self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D)
        self.text_proj_for_ca = nn.Sequential(
            nn.Linear(self.Ct5, self.D),
            nn.GELU(approximate='tanh'),
            nn.Linear(self.D, self.D),
        )
    else:
        if selecting_idx is None:
            num_classes = 1000
            print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======')
            selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device())
        self.selecting_idx = selecting_idx
        self.num_classes = selecting_idx.shape[-1]
        self.D = self.C
        self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
        nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)

    self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
    nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
    if self.rope2d_each_sa_layer:
        rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw)
        self.rope2d_freqs_grid = rope2d_freqs_grid
    else:
        raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented')
    self.lvl_embed = nn.Embedding(15, self.C)
    nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

    norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps)
    self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity()
    self.word_embed = nn.Linear(self.d_vae, self.C)

    self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()

    if fused_norm:
        fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm
        if fused_norm_func is not None:
            # NOTE(review): these tensors are created and immediately discarded. They are kept
            # only because removing them would shift the global RNG stream consumed by the
            # weight initialisation below.
            B = 2
            x = torch.randn(B, 1, self.C).requires_grad_(True)
            scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
            shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
            del B, x, scale, shift
    else:
        fused_norm_func = None

    self.use_flex_attn = use_flex_attn
    self.attn_fn_compile_dict = {}
    if self.use_flex_attn:
        self.attn_fn_compile_dict = self.compile_flex_attn()

    self.drop_path_rate = drop_path_rate
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    # Plain Python list: the blocks are registered below, either directly or via chunks.
    self.unregistered_blocks = []
    for block_idx in range(depth):
        block = (CrossAttnBlock if self.t2i else SelfAttnBlock)(
            embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer,
            num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn,
            swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func,
            checkpointing_sa_only=self.checkpointing == 'self-attn',
            use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
        )
        self.unregistered_blocks.append(block)

    V = self.V
    if head_aln:
        self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func)
        self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
    else:
        self.head_nm = MultiInpIdentity()
        self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))

    # Use the normalised chunk count everywhere; the raw `block_chunks` may be None/0.
    self.num_block_chunks = block_chunks or 1
    self.num_blocks_in_a_chunk = depth // self.num_block_chunks
    print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}")
    assert self.num_blocks_in_a_chunk * self.num_block_chunks == depth, \
        f'depth={depth} is not divisible by block_chunks={self.num_block_chunks}'
    if self.num_block_chunks == 1:
        self.blocks = nn.ModuleList(self.unregistered_blocks)
    else:
        self.block_chunks = nn.ModuleList()
        for i in range(self.num_block_chunks):
            self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk))
    print(
        f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n'
        f' [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n'
        f' [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
        end='\n\n', flush=True
    )
| |
|
|
def compile_flex_attn(self):
    """Pre-build one ``FlexAttn`` per scale schedule that may be seen at run time.

    For every aspect ratio in ``self.train_h_div_w_list``, the closest entry of
    ``h_div_w_templates`` selects a full schedule from
    ``dynamic_resolution_h_w[template][self.pn]['scales']``.
    - Inference mode: a FlexAttn is built for every prefix length of the schedule,
      using the KV-cache mask and auto padding.
    - Training mode: a FlexAttn is built only for the first
      ``min(self.always_training_scales, len(schedule))`` scales, using the 'var' mask.
    Temporal sizes are clipped to ``self.video_frames // 4 + 1``. When
    ``video_frames > 1``, a single-frame (t == 1) variant of the last built schedule is
    also registered. Sequence lengths are rounded up to ``self.pad_to_multiplier``.

    :return: dict mapping ``tuple(scale_schedule)`` to its FlexAttn
    """
    def padded_len(schedule):
        # Round the total token count up to the next multiple of pad_to_multiplier.
        seq_len = sum(pt * ph * pw for pt, ph, pw in schedule)
        remainder = seq_len % self.pad_to_multiplier
        return seq_len + (self.pad_to_multiplier - remainder) if remainder != 0 else seq_len

    compiled = {}
    for h_div_w in self.train_h_div_w_list:
        template = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))]
        full_scale_schedule = dynamic_resolution_h_w[template][self.pn]['scales']
        if not self.inference_mode:
            mask_type, auto_padding = 'var', False
            scale_counts = [min(self.always_training_scales, len(full_scale_schedule))]
        else:
            mask_type, auto_padding = "infinity_infer_mask_with_kv_cache", True
            scale_counts = list(range(1, 1 + len(full_scale_schedule)))

        max_t = self.video_frames // 4 + 1
        for scales_num in scale_counts:
            print(f'====== apply flex attn hdivw: {h_div_w} scales: {scales_num} ======')
            scale_schedule = [(min(t, max_t), h, w) for (t, h, w) in full_scale_schedule[:scales_num]]
            key = tuple(scale_schedule)
            compiled[key] = FlexAttn(
                block_scales=key,
                mask_type=mask_type,
                B=self.batch_size,
                H=self.num_heads,
                L=padded_len(key),
                auto_padding=auto_padding,
            )

        if self.video_frames > 1:
            # Videos also need the image (single-frame) attention pattern.
            scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
            key = tuple(scale_schedule)
            compiled[key] = FlexAttn(
                block_scales=key,
                mask_type=mask_type,
                B=self.batch_size,
                H=self.num_heads,
                L=padded_len(key),
            )
    return compiled
| |
def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
    """Project hidden states to output logits in float32.

    CUDA autocast is disabled, and ``h`` and ``cond_BD`` are cast to float32 before
    ``self.head(self.head_nm(h, cond_BD))`` is evaluated.

    :param h: hidden states, shaped (B, L, C)
    :param cond_BD: conditioning vector, shaped (B, D); may be None, in which case None
        is forwarded to ``self.head_nm`` (only meaningful when ``head_nm`` ignores it)
    :return: logits; with a linear head over the last dim this is (B, L, V)
    """
    with torch.amp.autocast('cuda', enabled=False):
        cond = None if cond_BD is None else cond_BD.float()
        return self.head(self.head_nm(h.float(), cond))
|
|
def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0):
    """Add the level embedding of one scale to its tokens, in place.

    :param feature: tensor shaped (B, L, C). The first t*h*w tokens belong to scale
        ``scale_ind``; the trailing ``need_to_pad`` tokens are left untouched.
    :param scale_ind: scale index; also the row of ``self.lvl_embed`` that is added.
    :param scale_schedule: sequence of (t, h, w) per scale.
    :param need_to_pad: number of trailing padding tokens in ``feature``.
    :return: ``feature`` itself (modified in place).
    :raises AssertionError: if t*h*w + need_to_pad != L.
    """
    _bs, seq_len, _c = feature.shape
    patch_t, patch_h, patch_w = scale_schedule[scale_ind]
    t_mul_h_mul_w = patch_t * patch_h * patch_w
    assert t_mul_h_mul_w + need_to_pad == seq_len, f'{t_mul_h_mul_w=} + {need_to_pad=} != {seq_len=}'
    # Every token of this scale gets the same row, so look it up once on the feature's
    # device and broadcast it. This avoids building a (B, t*h*w) index tensor on the CPU
    # and copying it over on each call.
    level_idx = torch.full((1, 1), scale_ind, dtype=torch.long, device=feature.device)
    feature[:, :t_mul_h_mul_w] += self.lvl_embed(level_idx)
    return feature
| |
def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0):
    """Add every scale's level embedding to its span of a packed token sequence.

    ``x_BLC`` holds the t*h*w tokens of each scale, in ``scale_schedule`` order,
    followed by ``need_to_pad`` padding tokens. Each scale span is passed to
    ``self.add_lvl_embeding``; the results and the untouched padding tail are
    concatenated along dim 1.
    """
    pieces = []
    start = 0
    for level, thw in enumerate(scale_schedule):
        stop = start + np.array(thw).prod()
        pieces.append(self.add_lvl_embeding(x_BLC[:, start:stop], level, scale_schedule))
        start = stop
    assert x_BLC.shape[1] == (start + need_to_pad), f'{x_BLC.shape[1]} != {start} + {need_to_pad}'
    pieces.append(x_BLC[:, start:])
    return torch.cat(pieces, dim=1)
|
|
def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, Any, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
    cfg_infer=False,
    **kwargs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Teacher-forced forward pass returning logits for every input position.

    :param label_B_or_BLT: text condition, unpacked as
        ``(kv_compact, lens, cu_seqlens_k, max_seqlen_k)``. These appear to be the text
        tokens of all samples packed along dim 0, the per-sample token counts, their
        cumulative offsets and the longest count (TODO confirm against the data loader).
        The caller's ``kv_compact`` is NOT modified.
    :param x_BLC_wo_prefix: (B, L - 1, d_vae) token features of all scales, without the
        start token that is prepended here; L = sum(t*h*w over scale_schedule).
    :param scale_schedule: list of (t, h, w) per scale.
    :param cfg_infer: if True, delegate to ``autoregressive_infer_cfg`` with ``**kwargs``.
    :return: logits from ``self.get_logits`` for the L unpadded positions.
    """
    if cfg_infer:
        return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)

    x_BLC_wo_prefix = x_BLC_wo_prefix.float()
    B = x_BLC_wo_prefix.shape[0]

    # [1. build the input sequence and the attention bias in float32]
    with torch.amp.autocast('cuda', enabled=False):
        kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT

        # Classifier-free-guidance dropout: replace a sample's text tokens by cfg_uncond.
        # Clone before the first write so the caller's tensor is not modified in place.
        cloned = False
        total = 0
        for le in lens:
            if random.random() < self.cond_drop_rate:
                if not cloned:
                    kv_compact = kv_compact.clone()
                    cloned = True
                kv_compact[total:total+le] = self.cfg_uncond[:le]
            total += le
        # Zero-valued term that keeps cfg_uncond in the autograd graph even when no
        # sample was dropped (presumably so every rank produces a gradient for it).
        must_on_graph = self.cfg_uncond[0, 0] * 0
        kv_compact = self.text_norm(kv_compact).contiguous()
        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous()
        kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
        kv_compact[0, 0] += must_on_graph
        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k

        cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous()

        sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
        x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)

        # [1.1. pad the sequence length up to a multiple of pad_to_multiplier]
        l_end = x_BLC.shape[1]
        need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end

        if self.customized_flash_attn:
            Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end]
            Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end]
            attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen)
        elif self.use_flex_attn:
            if need_to_pad:
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
            assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'
            attn_bias_or_two_vector = None
        else:
            # Block-causal mask: a token may attend to every token of its own scale and
            # of all earlier scales. Built directly on the target device.
            device = x_BLC.device
            d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i, device=device) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1)
            dT = d.transpose(1, 2)
            attn_bias = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end)
            if need_to_pad:
                attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf)
                # Padded queries attend to position 0 so their softmax row is not all -inf.
                attn_bias[0, 0, l_end:, 0] = 0
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
            attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device)

    if self.use_flex_attn:
        attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)]
    else:
        attn_fn = None

    # [2. block loop]
    checkpointing_full_block = self.checkpointing == 'full-block' and self.training
    if self.num_block_chunks == 1:
        for i, b in enumerate(self.blocks):
            if self.add_lvl_embeding_only_first_block and i == 0:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if checkpointing_full_block:
                x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
            else:
                x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
    else:
        for i, chunk in enumerate(self.block_chunks):
            if self.add_lvl_embeding_only_first_block and i == 0:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)

    # Drop the padding before the head.
    return self.get_logits(x_BLC[:, :l_end], cond_BD)
|
|
@torch.no_grad()
def autoregressive_infer_cfg(
    self,
    vae=None,
    scale_schedule=None,
    label_B_or_BLT=None,
    B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
    g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
    returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
    cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
    vae_type=0, softmax_merge_topk=-1, ret_img=False,
    trunk_scale=1000,
    gt_leak=0, gt_ls_Bl=None,
    inference_mode=False,
    save_img_path=None,
    sampling_per_bits=1,
):
    """Sample token maps scale by scale, with KV caching and classifier-free guidance.

    :param vae: used to turn sampled bit labels into codes and an image
        (``vae_type != 0``), or to visualise ``ret`` (``vae_type == 0``).
    :param scale_schedule: list of (t, h, w) per scale.
    :param label_B_or_BLT: text condition ``(kv_compact, lens, cu_seqlens_k, max_seqlen_k)``.
    :param B: number of samples. If any ``cfg_list`` entry != 1, the batch is doubled
        with an unconditional half (``cfg_uncond``) or ``negative_label_B_or_BLT``.
    :param g_seed: if not None, ``self.rng`` is seeded with it for sampling.
    :param cfg_list: per-scale guidance scale (indexed by scale).
    :param tau_list: per-scale temperature (indexed by scale).
    :param cfg_insertion_layer: where guidance is mixed.
        0 mixes the logits; 1 is accepted but currently has no effect;
        a negative k mixes the output of block ``len(unregistered_blocks) + k``.
    :param trunk_scale: stop after this many scales.
    :param gt_leak, gt_ls_Bl: for the first ``gt_leak`` scales, use ``gt_ls_Bl[si]``
        instead of the sampled labels.
    :param inference_mode: toggle KV caching on ``self.unregistered_blocks``. Otherwise
        caching is toggled through ``self.block_chunks[i].module.module`` (presumably
        FSDP-wrapped), which requires more than one chunk.
    :param ret_img: if True, also decode a uint8 (B, H, W, 3) image batch, channel-flipped.
    force_gt_Bhw, cfg_sc, ratio_Bl1, gumbel, norm_cfg, cfg_exp_k, softmax_merge_topk,
    save_img_path and sampling_per_bits are accepted for interface compatibility and
    are not used here.
    :return: ``(ret, idx_Bl_list, img)``. ``ret`` and ``idx_Bl_list`` are only filled on
        the ``vae_type == 0`` path; ``img`` is ``[]`` unless ``ret_img``.

    KV caching is always switched off again, even if sampling raises.
    """
    if g_seed is None:
        rng = None
    else:
        self.rng.manual_seed(g_seed)
        rng = self.rng
    assert len(cfg_list) >= len(scale_schedule)
    assert len(tau_list) >= len(scale_schedule)

    if self.apply_spatial_patchify:
        vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
    else:
        vae_scale_schedule = scale_schedule

    # [1. text condition; with CFG, append the unconditional / negative half]
    kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
    if any(np.array(cfg_list) != 1):
        bs = 2*B
        if not negative_label_B_or_BLT:
            kv_compact_un = kv_compact.clone()
            total = 0
            for le in lens:
                kv_compact_un[total:total+le] = (self.cfg_uncond)[:le]
                total += le
            kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
            cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0)
        else:
            kv_compact_un, _lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
            kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
            cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0)
            max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
    else:
        bs = B

    kv_compact = self.text_norm(kv_compact)
    sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k))
    kv_compact = self.text_proj_for_ca(kv_compact)
    ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
    last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)

    with torch.amp.autocast('cuda', enabled=False):
        cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()
    accu_BChw, ret = None, []
    idx_Bl_list = []

    # [2. parse where CFG is applied; done before enabling the KV cache so bad input
    #     cannot leave caching switched on]
    abs_cfg_insertion_layers = []
    add_cfg_on_logits = False
    leng = len(self.unregistered_blocks)
    for item in cfg_insertion_layer:
        if item == 0:
            add_cfg_on_logits = True
        elif item == 1:
            pass  # "CFG on probabilities" is accepted but not implemented here
        elif item < 0:
            assert leng+item >= 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={leng}'
            abs_cfg_insertion_layers.append(leng+item)
        else:
            raise ValueError(f'cfg_insertion_layer: {item} is not valid')

    # Groups of layers visited per scale: one per chunk, or one per block when unchunked
    # (matching how `forward` applies the level embedding in each layout).
    if self.num_block_chunks > 1:
        layer_groups = [chunk.module for chunk in self.block_chunks]
    else:
        layer_groups = [[blk] for blk in self.blocks]

    def _set_kv_caching(enabled):
        if inference_mode:
            for blk in self.unregistered_blocks:
                (blk.sa if isinstance(blk, CrossAttnBlock) else blk.attn).kv_caching(enabled)
        else:
            assert self.num_block_chunks > 1
            for block_chunk_ in self.block_chunks:
                for module in block_chunk_.module.module:
                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(enabled)

    num_stages_minus_1 = len(scale_schedule)-1
    summed_codes = 0
    _set_kv_caching(True)
    try:
        # [3. scale-by-scale sampling]
        for si, pn in enumerate(scale_schedule):
            cfg = cfg_list[si]
            if si >= trunk_scale:
                break

            need_to_pad = 0
            attn_fn = None
            if self.use_flex_attn:
                attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None)

            layer_idx = 0
            for block_idx, layers in enumerate(layer_groups):
                if self.add_lvl_embeding_only_first_block and block_idx == 0:
                    last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)
                if not self.add_lvl_embeding_only_first_block:
                    last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)
                for m in layers:
                    last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si)
                    if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers):
                        # Mix conditional/unconditional halves, then duplicate to keep bs rows.
                        last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:]
                        last_stage = torch.cat((last_stage, last_stage), 0)
                    layer_idx += 1

            if (cfg != 1) and add_cfg_on_logits:
                logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si])
                logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:]
            else:
                logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si])

            if self.use_bit_label:
                # Each vocabulary pair is one binary decision: sample bits independently.
                tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
                logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
                idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
                idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)
            else:
                idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]

            if vae_type != 0:
                assert returns_vemb
                if si < gt_leak:
                    idx_Bld = gt_ls_Bl[si]
                else:
                    assert pn[0] == 1
                    idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1)
                    if self.apply_spatial_patchify:
                        # Undo the 2x2 spatial patchify on the label map.
                        idx_Bld = idx_Bld.permute(0,3,1,2)
                        idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2)
                        idx_Bld = idx_Bld.permute(0,2,3,1)
                    idx_Bld = idx_Bld.unsqueeze(1)

                codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label')
                if si != num_stages_minus_1:
                    # Accumulate at full resolution, then resample to the next scale's input.
                    summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
                    last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_down)
                    last_stage = last_stage.squeeze(-3)
                    if self.apply_spatial_patchify:
                        last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2)
                    last_stage = last_stage.reshape(*last_stage.shape[:2], -1)
                    last_stage = torch.permute(last_stage, [0,2,1])
                else:
                    summed_codes += codes
            else:
                if si < gt_leak:
                    idx_Bl = gt_ls_Bl[si]
                h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float()
                h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2])
                ret.append(h_BChw if returns_vemb != 0 else idx_Bl)
                idx_Bl_list.append(idx_Bl)
                if si != num_stages_minus_1:
                    accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule)

            if si != num_stages_minus_1:
                last_stage = self.word_embed(self.norm0_ve(last_stage))
                # Re-duplicate for the CFG half of the batch.
                last_stage = last_stage.repeat(bs//B, 1, 1)
    finally:
        _set_kv_caching(False)

    if not ret_img:
        return ret, idx_Bl_list, []

    if vae_type != 0:
        img = vae.decode(summed_codes.squeeze(-3))
    else:
        img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True)

    # [-1, 1] -> uint8 [0, 255], channels last, channel order flipped.
    img = (img + 1) / 2
    img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
    return ret, idx_Bl_list, img
| |
| @torch.no_grad() |
| def autoregressive_infer_cfg_w_lq_token( |
| self, |
| vae=None, |
| scale_schedule=None, |
| label_B_or_BLT=None, |
| B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None, |
| g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0, |
| returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False, |
| cfg_exp_k: float=0.0, cfg_insertion_layer=[-5], |
| vae_type=0, softmax_merge_topk=-1, ret_img=False, |
| trunk_scale=1000, |
| gt_leak=0, gt_ls_Bl=None, |
| inference_mode=False, |
| save_img_path=None, |
| sampling_per_bits=1, |
| x_BLC_wo_prefix_lq=None, |
| gt_BL_list=[], |
| ): |
| if g_seed is None: rng = None |
| else: self.rng.manual_seed(g_seed); rng = self.rng |
| assert len(cfg_list) >= len(scale_schedule) |
| assert len(tau_list) >= len(scale_schedule) |
| |
| |
| |
| |
| if self.apply_spatial_patchify: |
| vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] |
| else: |
| vae_scale_schedule = scale_schedule |
| |
| kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT |
| if any(np.array(cfg_list) != 1): |
| bs = 2*B |
| if not negative_label_B_or_BLT: |
| kv_compact_un = kv_compact.clone() |
| total = 0 |
| for le in lens: |
| kv_compact_un[total:total+le] = (self.cfg_uncond)[:le] |
| total += le |
| kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) |
| cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0) |
| else: |
| kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT |
| kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) |
| cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0) |
| max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un) |
| else: |
| bs = B |
|
|
| kv_compact = self.text_norm(kv_compact) |
| sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) |
| kv_compact = self.text_proj_for_ca(kv_compact) |
| ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k |
| last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1) |
| |
| x_BLC_wo_prefix_lq = x_BLC_wo_prefix_lq.float() |
| x_BLC_wo_prefix_lq = x_BLC_wo_prefix_lq.expand(bs,-1,-1) |
| x_BLC_lq = torch.cat((last_stage, self.word_embed(self.norm0_ve(x_BLC_wo_prefix_lq))), dim=1) |
| |
| patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] |
| x_BLC_lq_list = list(torch.split(x_BLC_lq,patch_nums_per_level,dim=1)) |
|
|
| with torch.amp.autocast('cuda', enabled=False): |
| cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous() |
| accu_BChw, cur_L, ret = None, 0, [] |
| idx_Bl_list, idx_Bld_list = [], [] |
|
|
| if inference_mode: |
| for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True) |
| else: |
| assert self.num_block_chunks > 1 |
| for block_chunk_ in self.block_chunks: |
| for module in block_chunk_.module.module: |
| (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True) |
| |
| abs_cfg_insertion_layers = [] |
| add_cfg_on_logits, add_cfg_on_probs = False, False |
| leng = len(self.unregistered_blocks) |
| for item in cfg_insertion_layer: |
| if item == 0: |
| add_cfg_on_logits = True |
| elif item == 1: |
| add_cfg_on_probs = True |
| elif item < 0: |
| assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={self.num_block_chunks}' |
| abs_cfg_insertion_layers.append(leng+item) |
| else: |
| raise ValueError(f'cfg_insertion_layer: {item} is not valid') |
| |
| num_stages_minus_1 = len(scale_schedule)-1 |
| summed_codes = 0 |
| logits_BlV_list = [] |
| for si, pn in enumerate(scale_schedule): |
| cfg = cfg_list[si] |
| if si >= trunk_scale: |
| break |
| cur_L += np.array(pn).prod() |
| |
| |
| if si < 10: |
| last_stage = x_BLC_lq_list[si] |
| |
|
|
| need_to_pad = 0 |
| attn_fn = None |
| if self.use_flex_attn: |
| |
| |
| |
| attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None) |
|
|
| |
| layer_idx = 0 |
| for block_idx, b in enumerate(self.block_chunks): |
| |
| if self.add_lvl_embeding_only_first_block and block_idx == 0: |
| last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) |
| if not self.add_lvl_embeding_only_first_block: |
| last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) |
| for m in b.module: |
| last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si) |
| if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers): |
| |
| last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:] |
| last_stage = torch.cat((last_stage, last_stage), 0) |
| layer_idx += 1 |
| |
| if (cfg != 1) and add_cfg_on_logits: |
| |
| logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si]) |
| logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:] |
| else: |
| logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si]) |
| |
| logits_BlV_list.append(logits_BlV) |
| |
| if self.use_bit_label: |
| tmp_bs, tmp_seq_len = logits_BlV.shape[:2] |
| logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) |
| |
| idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) |
| |
| |
| if si < len(gt_BL_list): |
| idx_Bld = gt_BL_list[si] |
| |
| else: |
| idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] |
| if vae_type != 0: |
| assert returns_vemb |
| if si < gt_leak: |
| idx_Bld = gt_ls_Bl[si] |
| else: |
| assert pn[0] == 1 |
| idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) |
| if self.apply_spatial_patchify: |
| idx_Bld = idx_Bld.permute(0,3,1,2) |
| idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) |
| idx_Bld = idx_Bld.permute(0,2,3,1) |
| idx_Bld = idx_Bld.unsqueeze(1) |
|
|
| idx_Bld_list.append(idx_Bld) |
| codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') |
| if si != num_stages_minus_1: |
| summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) |
| last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_down) |
| last_stage = last_stage.squeeze(-3) |
| if self.apply_spatial_patchify: |
| last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2) |
| last_stage = last_stage.reshape(*last_stage.shape[:2], -1) |
| last_stage = torch.permute(last_stage, [0,2,1]) |
| else: |
| summed_codes += codes |
| else: |
| if si < gt_leak: |
| idx_Bl = gt_ls_Bl[si] |
| h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float() |
|
|
| |
| h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2]) |
| ret.append(h_BChw if returns_vemb != 0 else idx_Bl) |
| idx_Bl_list.append(idx_Bl) |
| if si != num_stages_minus_1: |
| accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule) |
| |
| if si != num_stages_minus_1: |
| last_stage = self.word_embed(self.norm0_ve(last_stage)) |
| last_stage = last_stage.repeat(bs//B, 1, 1) |
|
|
| if inference_mode: |
| for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) |
| else: |
| assert self.num_block_chunks > 1 |
| for block_chunk_ in self.block_chunks: |
| for module in block_chunk_.module.module: |
| (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) |
|
|
| if not ret_img: |
| return ret, idx_Bl_list, [] |
| |
| if vae_type != 0: |
| img = vae.decode(summed_codes.squeeze(-3)) |
| else: |
| img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True) |
|
|
| img = (img + 1) / 2 |
| img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) |
| |
| logits_BlV_all = torch.cat(logits_BlV_list,dim = 1) |
| return ret, idx_Bl_list, img, logits_BlV_all |
| |
def logits_to_img(self, logits_BlV_all, vae, scale_schedule, top_k=900, top_p=0.97, g_seed=1):
    """Decode concatenated multi-scale bit logits into an image with the VAE.

    Per scale, the logits are reshaped into pairs of two, turned into (relaxed) bit
    labels with ``GumbelArgmax``, converted to codes by
    ``vae.quantizer.lfq.indices_to_codes(..., label_type='bit_label')``, upsampled to
    the last scale's size (all but the last scale), summed, and finally decoded with
    ``vae.decode``.

    :param logits_BlV_all: logits for all scales concatenated along dim 1; the
        per-scale lengths are ``t*h*w`` from ``scale_schedule``.
    :param vae: VAE providing ``quantizer.lfq.indices_to_codes``,
        ``quantizer.z_interplote_up`` and ``decode``.
    :param scale_schedule: per-scale ``(t, h, w)`` patch counts; every ``t`` must be 1.
    :param top_k: unused by the supported (bit-label) path; kept for compatibility.
    :param top_p: unused by the supported (bit-label) path; kept for compatibility.
    :param g_seed: if not None, ``self.rng`` is reseeded with it.
    :return: the output of ``vae.decode`` for the summed codes.
    :raises NotImplementedError: if ``self.use_bit_label`` is false.
    """
    if not self.use_bit_label:
        # The former non-bit-label branch sampled into `idx_Bl` but then read the
        # never-assigned `idx_Bld`, and everything below is bit-label specific, so it
        # could only fail with a NameError. Fail explicitly instead.
        raise NotImplementedError('logits_to_img only supports use_bit_label=True')

    patch_nums_per_level = [pt * ph * pw for pt, ph, pw in scale_schedule]
    logits_BlV_list = torch.split(logits_BlV_all, patch_nums_per_level, dim=1)
    B = logits_BlV_all.shape[0]

    if g_seed is not None:
        # NOTE(review): GumbelArgmax below is not handed this generator, so reseeding
        # it may not make the sampling reproducible -- confirm.
        self.rng.manual_seed(g_seed)

    if self.apply_spatial_patchify:
        vae_scale_schedule = [(pt, 2 * ph, 2 * pw) for pt, ph, pw in scale_schedule]
    else:
        vae_scale_schedule = scale_schedule

    summed_codes = 0
    last_si = len(scale_schedule) - 1
    for si, logits_BlV in enumerate(logits_BlV_list):
        pn = scale_schedule[si]
        tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
        logits_pairs = logits_BlV.reshape(tmp_bs, -1, 2)
        # The second argument (0.5) is presumably the Gumbel temperature -- confirm.
        relaxed = GumbelArgmax(logits_pairs, 0.5)
        # Weight on the index-1 entry of each pair serves as the bit label; identical
        # to the previous "multiply by [0, 1] mask, then sum over the last dim".
        idx_Bld = relaxed[:, :, 1:].sum(dim=-1)
        idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)

        assert pn[0] == 1
        idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1)
        if self.apply_spatial_patchify:
            idx_Bld = idx_Bld.permute(0, 3, 1, 2)
            idx_Bld = F.pixel_shuffle(idx_Bld, 2)
            idx_Bld = idx_Bld.permute(0, 2, 3, 1)
        idx_Bld = idx_Bld.unsqueeze(1)

        codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label')
        if si != last_si:
            summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
        else:
            summed_codes += codes

    return vae.decode(summed_codes.squeeze(-3))
| |
def forward_teacher(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
    cfg_infer=False,
    **kwargs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Teacher-forced forward pass producing logits for every token position.

    label_B_or_BLT: packed text condition, unpacked here as
        ``(kv_compact, lens, cu_seqlens_k, max_seqlen_k)``.
    x_BLC_wo_prefix: teacher-forced token features without the start token.
    scale_schedule: per-scale ``(t, h, w)`` patch counts.
    cfg_infer: when true, delegate to ``autoregressive_infer_cfg`` (with ``kwargs``).
    :return: logits BLV, V is vocab_size
    """
    if cfg_infer:
        return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)

    x_BLC_wo_prefix = x_BLC_wo_prefix.float()
    batch = x_BLC_wo_prefix.shape[0]

    # Condition preparation, input embedding and mask construction run with CUDA
    # autocast disabled.
    with torch.amp.autocast('cuda', enabled=False):
        text_kv, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT

        # Condition dropout: with probability cond_drop_rate a sample's text span is
        # overwritten by the leading rows of cfg_uncond.
        # NOTE(review): this writes into the caller's tensor in place.
        offset = 0
        for span in lens:
            if random.random() < self.cond_drop_rate:
                text_kv[offset:offset + span] = self.cfg_uncond[:span]
            offset += span

        # Zero-valued term tied to cfg_uncond; appears intended to keep cfg_uncond in
        # the autograd graph even when no sample was dropped.
        zero_on_graph = self.cfg_uncond[0, 0] * 0
        text_kv = self.text_norm(text_kv).contiguous()
        cond_BD = self.text_proj_for_sos((text_kv, cu_seqlens_k, max_seqlen_k)).float().contiguous()
        text_kv = self.text_proj_for_ca(text_kv).contiguous()
        text_kv[0, 0] += zero_on_graph
        ca_kv = text_kv, cu_seqlens_k, max_seqlen_k

        cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous()

        # The pooled text vector (plus pos_start) is the first token of the sequence.
        start_tok = cond_BD.unsqueeze(1).expand(batch, 1, -1) + self.pos_start.expand(batch, 1, -1)
        x_BLC = torch.cat((start_tok, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)

        l_end = x_BLC.shape[1]
        mult = self.pad_to_multiplier
        need_to_pad = (l_end + mult - 1) // mult * mult - l_end

        if self.customized_flash_attn:
            attn_bias_or_two_vector = (self.Infinity_visible_kvlen[:l_end], self.Infinity_invisible_qlen[:l_end])
        elif self.use_flex_attn:
            if need_to_pad:
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
            assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'
            attn_bias_or_two_vector = None
        else:
            # Dense mask: entry (i, j) is 0 when scale(i) >= scale(j), else -inf.
            level_ids = torch.cat([torch.full((pn[0] * pn[1] * pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1)
            attn_bias = torch.where(level_ids >= level_ids.transpose(1, 2), 0., -torch.inf).reshape(1, 1, l_end, l_end)
            attn_bias = attn_bias[:, :, :l_end, :l_end].contiguous()
            if need_to_pad:
                attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf)
                # Padded rows may see column 0, presumably so no row is fully masked.
                attn_bias[0, 0, l_end:, 0] = 0
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
            attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device)

    attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)] if self.use_flex_attn else None

    checkpointing_full_block = self.checkpointing == 'full-block' and self.training
    single_chunk = self.num_block_chunks == 1
    stages = self.blocks if single_chunk else self.block_chunks
    for i, stage in enumerate(stages):
        # Level embeddings go in before the first stage only, or before every stage.
        if i == 0 or not self.add_lvl_embeding_only_first_block:
            x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
        if not single_chunk:
            x_BLC = stage(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)
        elif checkpointing_full_block:
            x_BLC = torch.utils.checkpoint.checkpoint(stage, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
        else:
            x_BLC = stage(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)

    # Drop any padding before projecting to logits.
    return self.get_logits(x_BLC[:, :l_end], cond_BD)
| |
| |
@for_visualize
def vis_key_params(self, ep):
    """Hook for visualizing key parameters at epoch ``ep``; currently does nothing."""
    return None
| |
def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False):
    """Load a checkpoint after adapting it to this model.

    Adjustments made before delegating to ``nn.Module.load_state_dict``:

    * Every entry whose key contains ``'cfg_uncond'`` is resized along dim 0 to
      match ``self.cfg_uncond``. A shorter checkpoint tensor is extended with this
      model's current trailing rows; a longer one is truncated. In both cases the
      result is cast to the current tensor's device and dtype.
    * Entries ``lvl_1L``, ``attn_bias_for_masking``, ``Infinity_visible_kvlen`` and
      ``Infinity_invisible_qlen`` are dropped. If this model has an attribute of that
      name, its current value is inserted instead.

    The edits are applied to a shallow copy, so the caller's mapping is not modified.

    :param state_dict: checkpoint mapping names to tensors.
    :param strict: forwarded to ``nn.Module.load_state_dict`` (defaults to False here).
    :param assign: forwarded to ``nn.Module.load_state_dict``.
    :return: the value returned by ``nn.Module.load_state_dict``.
    """
    from collections import OrderedDict

    # Work on a copy: the previous implementation rewrote entries of the caller's
    # dict. `_metadata` is re-attached because nn.Module.load_state_dict reads it.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        state_dict._metadata = metadata

    for k in list(state_dict):
        if 'cfg_uncond' in k:
            new = self.cfg_uncond.data
            # Cast once so both branches yield a tensor on the current device/dtype
            # (previously only the extension branch was cast, which matters for
            # assign=True).
            old = state_dict[k].to(device=new.device, dtype=new.dtype)
            min_tlen = min(old.shape[0], new.shape[0])
            if min_tlen == old.shape[0]:
                state_dict[k] = torch.cat((old, new[min_tlen:]))
            else:
                state_dict[k] = old[:min_tlen]

    for buf_name in ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen'):
        state_dict.pop(buf_name, None)
        if hasattr(self, buf_name):
            state_dict[buf_name] = getattr(self, buf_name)

    return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
| |
def special_init(
    self,
    aln_init: float,
    aln_gamma_init: float,
    scale_head: float,
    scale_proj: int,
):
    """Rescale selected weights in place after construction.

    :param aln_init: multiplier for the head AdaLN's last projection weight, and for
        rows ``2*C:`` of each block's last ``ada_lin`` layer (or slots ``2:`` of
        ``ada_gss``).
    :param aln_gamma_init: multiplier for rows ``:2*C`` of each block's last
        ``ada_lin`` layer (or slots ``:2`` of ``ada_gss``).
    :param scale_head: if >= 0, multiplies the output head's final Linear weight and
        zeroes its bias.
    :param scale_proj: when 1, every block's attention output projection(s) and
        ``ffn.fc2`` are multiplied by ``1/sqrt(2*depth)``; other values leave them
        unchanged.
    """
    if isinstance(self.head_nm, AdaLNBeforeHead):
        head_ada = self.head_nm.ada_lin[-1]
        head_ada.weight.data.mul_(aln_init)
        if getattr(head_ada, 'bias', None) is not None:
            head_ada.bias.data.zero_()

    if scale_head >= 0:
        if isinstance(self.head, nn.Linear):
            out_layer = self.head
        elif isinstance(self.head, nn.Sequential):
            out_layer = self.head[-1]
        else:
            out_layer = None
        if out_layer is not None:
            out_layer.weight.data.mul_(scale_head)
            out_layer.bias.data.zero_()

    n_blocks = len(self.unregistered_blocks)
    for blk in self.unregistered_blocks:
        if scale_proj == 1:
            factor = 1 / math.sqrt(2 * n_blocks)
            # Cross-attention blocks (t2i) have separate self/cross projections.
            targets = [blk.sa.proj, blk.ca.proj] if self.t2i else [blk.attn.proj]
            targets.append(blk.ffn.fc2)
            for layer in targets:
                layer.weight.data.mul_(factor)

        if hasattr(blk, 'ada_lin'):
            ada = blk.ada_lin[-1]
            split = 2 * self.C
            ada.weight.data[:split].mul_(aln_gamma_init)
            ada.weight.data[split:].mul_(aln_init)
            if getattr(ada, 'bias', None) is not None:
                ada.bias.data.zero_()
        elif hasattr(blk, 'ada_gss'):
            gss = blk.ada_gss.data
            gss[:, :, :2, :].mul_(aln_gamma_init)
            gss[:, :, 2:, :].mul_(aln_init)
| |
def extra_repr(self):
    """Return the extra text shown in this module's repr: its drop-path rate."""
    rate = self.drop_path_rate
    return f'drop_path_rate={rate}'
| |
def get_layer_id_and_scale_exp(self, para_name: str):
    """Not supported by this model.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
|
|
| class BInfinity(nn.Module): |
def __init__(
    self, vae_local,
    text_channels=0, text_maxlen=0,
    selecting_idx=None,
    embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4.,
    drop_rate=0., drop_path_rate=0.,
    norm_eps=1e-6, rms_norm=False,
    shared_aln=False, head_aln=True,
    cond_drop_rate=0.1,
    rand_uncond=False,
    cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False,
    raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
    head_depth=1,
    top_p=0.0, top_k=0.0,
    customized_flash_attn=False, fused_mlp=False, fused_norm=False,
    block_chunks=1,
    checkpointing=None,
    pad_to_multiplier=0,
    use_flex_attn=False,
    batch_size=2,
    add_lvl_embeding_only_first_block=1,
    use_bit_label=1,
    rope2d_each_sa_layer=0,
    rope2d_normalized_by_hw=0,
    pn=None,
    train_h_div_w_list=None,
    video_frames=1,
    always_training_scales=20,
    apply_spatial_patchify=0,
    inference_mode=False,
):
    """Construct the BInfinity transformer.

    Selected arguments:

    :param vae_local: VAE whose ``embed_dim``, ``vocab_size`` and
        ``quantizer.lfq.mask`` determine the token width and output vocabulary.
    :param text_channels: text-feature width; non-zero enables text-to-image mode
        (cross-attention blocks, text pooling for the start token, ``cfg_uncond``).
    :param text_maxlen: number of rows of ``cfg_uncond``.
    :param selecting_idx: class-selection weights for class-conditional mode;
        defaults to a uniform ``1/1000`` row when None.
    :param top_p: default nucleus threshold, clamped to [0, 1]; < 1e-5 disables it.
    :param top_k: default top-k; a value in (0, 1) is a fraction of the vocabulary
        size ``V``; values <= 0 or >= V disable it.
    :param block_chunks: number of ``MultipleLayers`` chunks the blocks are grouped
        into; ``None``/``0`` is treated as 1. ``depth`` must be divisible by it.
    :param pad_to_multiplier: sequence padding multiple, clamped to at least 1.
    :param use_flex_attn: if true, FlexAttn kernels are built by ``compile_flex_attn``.
    :param rope2d_each_sa_layer: must be 1; 0 raises ``ValueError``.
    :param apply_spatial_patchify: if true, the token width is 4x the VAE embed dim.
    """
    self.C = embed_dim
    self.inference_mode = inference_mode
    self.apply_spatial_patchify = apply_spatial_patchify
    if self.apply_spatial_patchify:
        self.d_vae = vae_local.embed_dim * 4
    else:
        self.d_vae = vae_local.embed_dim
    self.use_bit_label = use_bit_label
    self.codebook_dim = self.d_vae
    # Bit-label mode predicts two logits per bit.
    self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size
    self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None
    self.Ct5 = text_channels
    self.depth = depth
    self.num_heads = num_heads
    self.batch_size = batch_size
    self.mlp_ratio = mlp_ratio
    self.cond_drop_rate = cond_drop_rate
    self.norm_eps = norm_eps
    self.prog_si = -1
    self.pn = pn
    self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates
    self.video_frames = video_frames
    self.always_training_scales = always_training_scales

    assert add_lvl_embeding_only_first_block in [0, 1]
    self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block
    assert rope2d_each_sa_layer in [0, 1]
    self.rope2d_each_sa_layer = rope2d_each_sa_layer
    self.rope2d_normalized_by_hw = rope2d_normalized_by_hw
    print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, '
          f'self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}')
    head_up_method = ''
    word_patch_size = 1 if head_up_method in {'', 'no'} else 2
    if word_patch_size > 1:
        assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}'

    self.checkpointing = checkpointing
    self.pad_to_multiplier = max(1, pad_to_multiplier)

    # The customized kernel is detected by an 'Infinity*' argument name on flash_attn_func.
    customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames)
    self.customized_flash_attn = customized_flash_attn and customized_kernel_installed
    if customized_flash_attn and not customized_kernel_installed:
        import inspect, warnings
        file_path = inspect.getsourcefile(flash_attn_func)
        line_number = inspect.getsourcelines(flash_attn_func)[1]
        info = (
            f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<\n'
            f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n'
            f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n'
        )
        warnings.warn(info, ImportWarning)
        print(info, flush=True)

    self.raw_scale_schedule = raw_scale_schedule
    self.first_l = 1

    self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k))
    if self.top_p < 1e-5:
        self.top_p = 0
    if self.top_k >= self.V or self.top_k <= 0:
        self.top_k = 0

    # All ranks must agree on whether the fused flash ops are installed.
    t = torch.zeros(dist.get_world_size(), device=dist.get_device())
    t[dist.get_rank()] = float(flash_fused_op_installed)
    dist.barrier()
    dist.allreduce(t)
    assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}'

    super().__init__()
    self.rng = torch.Generator(device=dist.get_device())
    self.maybe_record_function = nullcontext
    self.text_maxlen = text_maxlen
    self.t2i = text_channels != 0

    init_std = math.sqrt(1 / self.C / 3)
    self.norm0_cond = nn.Identity()
    if self.t2i:
        self.selecting_idx = None
        self.num_classes = 0
        self.D = self.C

        # Unconditional text embedding used for condition dropout, drawn from a fixed
        # CPU generator (seed 0).
        cfg_uncond = torch.empty(self.text_maxlen, self.Ct5)
        rng = torch.Generator(device='cpu')
        rng.manual_seed(0)
        torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng)
        cfg_uncond /= self.Ct5 ** 0.5
        if rand_uncond:
            self.register_buffer('cfg_uncond', cfg_uncond)
        else:
            self.cfg_uncond = nn.Parameter(cfg_uncond)

        self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps)
        self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D)
        self.text_proj_for_ca = nn.Sequential(
            nn.Linear(self.Ct5, self.D),
            nn.GELU(approximate='tanh'),
            nn.Linear(self.D, self.D),
        )
    else:
        if selecting_idx is None:
            num_classes = 1000
            print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======')
            selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device())
        self.selecting_idx = selecting_idx
        self.num_classes = selecting_idx.shape[-1]
        self.D = self.C
        self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
        nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)

    self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
    nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
    if self.rope2d_each_sa_layer:
        rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw)
        self.rope2d_freqs_grid = rope2d_freqs_grid
    else:
        raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented')
    self.lvl_embed = nn.Embedding(15, self.C)
    nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

    norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps)
    self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity()
    self.word_embed = nn.Linear(self.d_vae, self.C)

    self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()

    if fused_norm:
        fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm
        if fused_norm_func is not None:
            # NOTE(review): these tensors are created and immediately discarded. Their
            # only observable effect is advancing the global torch RNG, which shifts
            # the initialization of the layers built below; kept so seeded runs are
            # unchanged.
            B = 2
            x = torch.randn(B, 1, self.C).requires_grad_(True)
            scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
            shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
            del B, x, scale, shift
    else:
        fused_norm_func = None

    self.use_flex_attn = use_flex_attn
    self.attn_fn_compile_dict = {}
    if self.use_flex_attn:
        self.attn_fn_compile_dict = self.compile_flex_attn()

    self.drop_path_rate = drop_path_rate
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    # Plain list; the blocks are registered below via self.blocks or self.block_chunks.
    self.unregistered_blocks = []
    for block_idx in range(depth):
        # NOTE(review): blocks receive the raw pad_to_multiplier argument, not the
        # clamped self.pad_to_multiplier -- confirm this is intended.
        block = (CrossAttnBlock if self.t2i else SelfAttnBlock)(
            embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer,
            num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn,
            swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func,
            checkpointing_sa_only=self.checkpointing == 'self-attn',
            use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
        )
        self.unregistered_blocks.append(block)

    V = self.V
    if head_aln:
        self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func)
        self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
    else:
        self.head_nm = MultiInpIdentity()
        self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))

    # Use the normalized chunk count everywhere: dividing by the raw argument crashed
    # for block_chunks=None/0 even though `or 1` is meant to accept them.
    self.num_block_chunks = block_chunks or 1
    self.num_blocks_in_a_chunk = depth // self.num_block_chunks
    print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}")
    assert self.num_blocks_in_a_chunk * self.num_block_chunks == depth
    if self.num_block_chunks == 1:
        self.blocks = nn.ModuleList(self.unregistered_blocks)
    else:
        self.block_chunks = nn.ModuleList()
        for i in range(self.num_block_chunks):
            self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk))
    print(
        f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n'
        f' [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n'
        f' [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
        end='\n\n', flush=True
    )
|
|
def compile_flex_attn(self):
    """Build FlexAttn kernels for the scale schedules this model will use.

    For each aspect ratio in ``self.train_h_div_w_list``, the nearest entry of
    ``h_div_w_templates`` selects ``dynamic_resolution_h_w[template][self.pn]['scales']``.

    * In inference mode, one kernel is built per prefix length of that schedule,
      with mask ``"infinity_infer_mask_with_kv_cache"`` and ``auto_padding=True``.
    * In training, a single kernel is built for the first
      ``min(always_training_scales, len(schedule))`` scales, with mask ``'var'`` and
      ``auto_padding=False``.

    Temporal sizes are capped at ``video_frames // 4 + 1``. When ``video_frames > 1``,
    an extra kernel is added for the last schedule built, with every ``t`` set to 1.
    Sequence lengths are rounded up to a multiple of ``pad_to_multiplier``.

    :return: dict mapping a tuple of per-scale ``(t, h, w)`` to its FlexAttn.
    """
    kernels = {}
    multiple = self.pad_to_multiplier

    def padded_length(schedule):
        # Total token count rounded up to a multiple of pad_to_multiplier.
        seq_len = sum(pt * ph * pw for pt, ph, pw in schedule)
        remainder = seq_len % multiple
        return seq_len + (multiple - remainder) if remainder != 0 else seq_len

    for h_div_w in self.train_h_div_w_list:
        template = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))]
        full_scale_schedule = dynamic_resolution_h_w[template][self.pn]['scales']
        if self.inference_mode:
            mask_type = "infinity_infer_mask_with_kv_cache"
            auto_padding = True
            scale_counts = list(range(1, 1 + len(full_scale_schedule)))
        else:
            mask_type = 'var'
            auto_padding = False
            scale_counts = [min(self.always_training_scales, len(full_scale_schedule))]

        t_cap = self.video_frames // 4 + 1
        for scales_num in scale_counts:
            print(f'====== apply flex attn hdivw: {h_div_w} scales: {scales_num} ======')
            scale_schedule = [(min(t, t_cap), h, w) for (t, h, w) in full_scale_schedule[:scales_num]]
            key = tuple(scale_schedule)
            kernels[key] = FlexAttn(block_scales=key,
                                    mask_type=mask_type,
                                    B=self.batch_size,
                                    H=self.num_heads,
                                    L=padded_length(key),
                                    auto_padding=auto_padding)

        if self.video_frames > 1:
            # NOTE(review): auto_padding is not passed here, unlike above, so FlexAttn's
            # default applies -- confirm that is intended in inference mode.
            image_key = tuple((1, h, w) for (_, h, w) in scale_schedule)
            kernels[image_key] = FlexAttn(block_scales=image_key,
                                          mask_type=mask_type,
                                          B=self.batch_size,
                                          H=self.num_heads,
                                          L=padded_length(image_key))
    return kernels
| |
def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
    """Project hidden states to output logits in float32, with CUDA autocast disabled.

    :param h: hidden_state, shaped (B or batch_size, L or seq_len, C or hidden_dim);
        cast to float32.
    :param cond_BD: condition shaped (B or batch_size, D or cond_dim), cast to float32
        and passed to ``self.head_nm``. ``None`` is forwarded as-is, which is only
        valid when ``head_nm`` ignores its condition argument.
    :return: logits, shaped (B or batch_size, L or seq_len, V or vocabulary_size)
    """
    with torch.amp.autocast('cuda', enabled=False):
        # Honour the Optional annotation: `.float()` on None would raise.
        cond = None if cond_BD is None else cond_BD.float()
        return self.head(self.head_nm(h.float(), cond))
|
|
def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0):
    """Add the level embedding of scale ``scale_ind`` to ``feature`` in place.

    Only the first ``t*h*w`` positions (that scale's tokens) are modified; the
    trailing ``need_to_pad`` positions are left untouched.

    :param feature: tensor unpacked as (batch, seq_len, channels); mutated in place.
    :param scale_ind: index into ``scale_schedule`` and into ``self.lvl_embed``.
    :param scale_schedule: per-scale ``(t, h, w)`` patch counts.
    :param need_to_pad: number of padding positions after this scale's tokens.
    :return: ``feature`` itself.
    """
    batch, total_len, _ = feature.shape
    pt, ph, pw = scale_schedule[scale_ind]
    n_tokens = pt * ph * pw
    assert n_tokens + need_to_pad == total_len
    level_ids = torch.full((batch, n_tokens), scale_ind, dtype=torch.int, device=feature.device)
    feature[:, :n_tokens] += self.lvl_embed(level_ids)
    return feature
| |
def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0):
    """Add each scale's level embedding to its token span of a full sequence.

    ``x_BLC`` is walked scale by scale. Each ``t*h*w`` span is passed to
    ``self.add_lvl_embeding``, and the spans plus any trailing padding are
    concatenated back along dim 1.

    :param x_BLC: sequence whose dim 1 holds all scales followed by ``need_to_pad``
        padding positions.
    :param scale_schedule: per-scale ``(t, h, w)`` patch counts.
    :param need_to_pad: number of trailing padding positions (asserted).
    :return: a new tensor of the same length as ``x_BLC``.
    """
    pieces = []
    start = 0
    for level, thw in enumerate(scale_schedule):
        end = start + math.prod(thw)
        pieces.append(self.add_lvl_embeding(x_BLC[:, start:end], level, scale_schedule))
        start = end
    assert x_BLC.shape[1] == (start + need_to_pad), f'{x_BLC.shape[1]} != {start} + {need_to_pad}'
    pieces.append(x_BLC[:, start:])
    return torch.cat(pieces, dim=1)
|
|
def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
    cfg_infer=False,
    **kwargs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Teacher-forced forward pass producing logits for every scale's tokens.

    label_B_or_BLT: packed text condition, unpacked as
        ``(kv_compact, lens, cu_seqlens_k, max_seqlen_k)``.
    x_BLC_wo_prefix: token features without the start token. Together with the start
        token, its length must equal ``sum(t*h*w)`` over ``scale_schedule``. It may
        also be that sum plus one extra copy of the last scale ("long input").
    scale_schedule: per-scale ``(t, h, w)`` patch counts (list or tuple).
    cfg_infer: when true, delegate to ``autoregressive_infer_cfg`` with ``kwargs``.
    :return: logits BLV for the first ``sum(t*h*w)`` positions, V is vocab_size
    """
    if cfg_infer:
        return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)

    x_BLC_wo_prefix = x_BLC_wo_prefix.float()
    B = x_BLC_wo_prefix.shape[0]

    with torch.amp.autocast('cuda', enabled=False):
        kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
        # Condition dropout: with probability cond_drop_rate a sample's text span is
        # replaced by the leading rows of cfg_uncond.
        # NOTE(review): this writes into the caller's kv_compact in place.
        total = 0
        for le in lens:
            if random.random() < self.cond_drop_rate:
                kv_compact[total:total+le] = self.cfg_uncond[:le]
            total += le
        # Zero-valued term tied to cfg_uncond; appears intended to keep it on the
        # autograd graph even when nothing was dropped.
        must_on_graph = self.cfg_uncond[0, 0] * 0
        kv_compact = self.text_norm(kv_compact).contiguous()
        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous()
        kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
        kv_compact[0, 0] += must_on_graph
        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k

        cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous()

        sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
        x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)

    all_scale_length = np.sum([np.prod(pn) for pn in scale_schedule])
    last_scale_length = np.prod(scale_schedule[-1])
    if x_BLC.shape[1] == all_scale_length:
        scale_schedule_new = list(scale_schedule)
    else:
        # "Long input": one extra copy of the last scale follows the regular sequence.
        assert x_BLC.shape[1] == all_scale_length + last_scale_length, \
            f'{x_BLC.shape[1]=} matches neither {all_scale_length} nor {all_scale_length + last_scale_length}'
        # list(...) so tuple schedules work too (tuple + list raised TypeError).
        scale_schedule_new = list(scale_schedule) + [scale_schedule[-1]]

    # This forward never pads x_BLC and runs the blocks without a mask or flex-attention
    # fn, so level embeddings are added with need_to_pad=0. Previously the padding
    # implied by pad_to_multiplier was passed, which tripped the length assertion in
    # add_lvl_embeding_for_x_BLC whenever it was non-zero. The former (unused)
    # attn_fn_compile_dict lookup was dropped as well; it could only raise KeyError.
    checkpointing_full_block = self.checkpointing == 'full-block' and self.training
    if self.num_block_chunks == 1:
        for i, b in enumerate(self.blocks):
            if i == 0 or not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule_new, 0)
            if checkpointing_full_block:
                x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, None, None, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
            else:
                x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
    else:
        for i, chunk in enumerate(self.block_chunks):
            if i == 0 or not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule_new, 0)
            x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)

    # Only the regular scales are projected; an extra long-input copy is discarded.
    return self.get_logits(x_BLC[:, :all_scale_length], cond_BD)
| |
def forward_teacher(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
    cfg_infer=False,
    **kwargs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Teacher-forced forward pass: text condition + input tokens of every scale -> logits.

    label_B_or_BLT: unpacked below as (kv_compact, lens, cu_seqlens_k, max_seqlen_k): packed
        text features, per-sample text lengths (`lens`), and the varlen metadata that is handed
        to text_proj_for_sos and to the cross-attention blocks.
    x_BLC_wo_prefix: input tokens of all scales without the start token; cast to float32.
    scale_schedule: list of (t, h, w) per scale; scale i contributes t*h*w tokens.
    cfg_infer: if True, return self.autoregressive_infer_cfg(...) instead (kwargs forwarded).

    Side effect: for each sample, with probability self.cond_drop_rate (drawn from Python's
    global `random`), that sample's rows of the caller's kv_compact are overwritten in place
    with self.cfg_uncond[:len].

    :return: logits BLV, V is vocab_size; sliced back to the unpadded length l_end
    """
    if cfg_infer:
        return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)

    x_BLC_wo_prefix = x_BLC_wo_prefix.float()
    B = x_BLC_wo_prefix.shape[0]

    # Conditioning, token embedding and mask construction run with CUDA autocast disabled.
    with torch.amp.autocast('cuda', enabled=False):
        kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT

        # Condition dropout: replace a whole sample's text rows with the unconditional embedding.
        # `total` is the running row offset into the packed kv_compact.
        total = 0
        for le in lens:
            if random.random() < self.cond_drop_rate:
                kv_compact[total:total+le] = self.cfg_uncond[:le]
            total += le
        # Zero-valued term added to kv_compact below so self.cfg_uncond is always part of the
        # autograd graph, even when no sample was dropped (presumably so it always gets a gradient).
        must_on_graph = self.cfg_uncond[0, 0] * 0
        kv_compact = self.text_norm(kv_compact).contiguous()
        # The pooled text feature is used both as the start token and as the AdaLN condition.
        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous()
        kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
        kv_compact[0, 0] += must_on_graph
        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k

        cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous()

        # Sequence layout: [start token] + embedded tokens of all scales.
        sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
        x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)

        # Number of rows needed to round the sequence length up to a multiple of pad_to_multiplier.
        l_end = x_BLC.shape[1]
        need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end

        if self.customized_flash_attn:
            # Customized flash-attention path: two precomputed vectors (sliced to l_end) replace a dense mask.
            Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end]
            Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end]
            attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen)

        elif self.use_flex_attn:
            # FlexAttention path: the mask lives in the precompiled attn_fn; only pad the sequence.
            if need_to_pad:
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
            # NOTE(review): this checks the last (channel) dim, not the padded sequence length —
            # confirm that is the intended constraint.
            assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'
            attn_bias_or_two_vector = None
        else:
            # Dense additive mask. d[i] is the scale index of token i; query i may attend to key j
            # iff scale(i) >= scale(j), i.e. full attention within a scale, causal across scales.
            d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1)
            dT = d.transpose(1, 2)
            attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end)
            attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous()
            if need_to_pad:
                attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf)
                # Give each padded query row one finite entry (key 0) so the row is not all -inf
                # (presumably to keep softmax from producing NaN).
                attn_bias[0, 0, l_end:, 0] = 0
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
            attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device)

    # FlexAttention kernels are precompiled per scale schedule (see compile_flex_attn).
    if self.use_flex_attn:
        attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)]
    else:
        attn_fn = None

    # No-op expression (attribute lookups only); it has no effect at runtime.
    SelfAttnBlock.forward, CrossAttnBlock.forward
    checkpointing_full_block = self.checkpointing == 'full-block' and self.training
    if self.num_block_chunks == 1:
        for i, b in enumerate(self.blocks):
            # Level embedding is added once before the first block, or before every block.
            if self.add_lvl_embeding_only_first_block and i == 0:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if checkpointing_full_block:
                x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
            else:
                x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
    else:
        for i, chunk in enumerate(self.block_chunks):
            if self.add_lvl_embeding_only_first_block and i == 0:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)

    # Drop padding rows before the head.
    return self.get_logits(x_BLC[:, :l_end], cond_BD)
| |
def logits_to_img(self, logits_BlV_all, vae, scale_schedule, top_k=900, top_p=0.97, g_seed=1):
    """Turn concatenated per-scale bit logits into an image decoded by ``vae``.

    :param logits_BlV_all: logits of every scale concatenated on dim 1; split into chunks of
        ``pt*ph*pw`` tokens following ``scale_schedule``.
    :param vae: object providing ``quantizer.lfq.indices_to_codes``,
        ``quantizer.z_interplote_up`` and ``decode``.
    :param scale_schedule: list of ``(pt, ph, pw)``; ``pt`` must be 1 for every scale.
    :param top_k: kept for interface compatibility; the bit-label sampler does not read it.
    :param top_p: kept for interface compatibility; the bit-label sampler does not read it.
    :param g_seed: when not None, ``self.rng`` is reseeded with it (the default 1 means every
        call with default arguments reseeds).
    :return: ``vae.decode(...)`` of the per-scale codes summed at the final resolution.
    :raises NotImplementedError: when ``self.use_bit_label`` is falsy.
    """
    # Only bit labels can be turned into codes here (`label_type='bit_label'` below). The former
    # non-bit branch never produced `idx_Bld`, so it could not complete; fail clearly instead.
    if not self.use_bit_label:
        raise NotImplementedError('logits_to_img only supports use_bit_label=True')

    patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]
    logits_BlV_list = list(torch.split(logits_BlV_all, patch_nums_per_level, dim=1))
    B = logits_BlV_all.shape[0]

    if g_seed is not None:
        self.rng.manual_seed(g_seed)

    if self.apply_spatial_patchify:
        vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
    else:
        vae_scale_schedule = scale_schedule

    summed_codes = 0
    num_stages_minus_1 = len(scale_schedule) - 1

    for si, logits_BlV in enumerate(logits_BlV_list):
        pn = scale_schedule[si]
        tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
        # One 2-way logit pair per bit.
        logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
        idx_Bld = GumbelArgmax(logits_BlV, 0.5)
        # Zero column 0 and sum: for a (presumably one-hot) 2-way sample this yields the bit value.
        keep_mask = torch.zeros_like(idx_Bld)
        keep_mask[:, :, 1:] = 1
        idx_Bld = (idx_Bld * keep_mask).sum(dim=-1)
        idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)

        assert pn[0] == 1
        idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1)
        if self.apply_spatial_patchify:
            # Fold the 2x2-patchified channels back into space.
            idx_Bld = idx_Bld.permute(0, 3, 1, 2)
            idx_Bld = F.pixel_shuffle(idx_Bld, 2)
            idx_Bld = idx_Bld.permute(0, 2, 3, 1)
        idx_Bld = idx_Bld.unsqueeze(1)

        codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label')
        # Every scale except the last is upsampled to the final resolution before summing.
        if si != num_stages_minus_1:
            summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
        else:
            summed_codes += codes

    img = vae.decode(summed_codes.squeeze(-3))
    return img
|
|
@torch.no_grad()
def autoregressive_infer_cfg(
    self,
    vae=None,
    scale_schedule=None,
    label_B_or_BLT=None,
    B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
    g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
    returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
    cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
    vae_type=0, softmax_merge_topk=-1, ret_img=False,
    trunk_scale=1000,
    gt_leak=0, gt_ls_Bl=None,
    inference_mode=False,
    save_img_path=None,
    sampling_per_bits=1,
    x_BLC_wo_prefix_lq=None
):
    """Run ``self.forward`` once and decode the resulting logits to a uint8 image.

    Only ``vae``, ``scale_schedule``, ``label_B_or_BLT``, ``g_seed``, ``top_k``, ``top_p`` and
    ``x_BLC_wo_prefix_lq`` are read; the other keyword arguments are accepted for signature
    compatibility and ignored. (The list defaults are never mutated here.)

    :return: ``(None, None, img)``. ``img`` is the decoded image mapped from ``[-1, 1]`` to
        ``[0, 255]`` (out-of-range values clamped), permuted with ``(0, 2, 3, 1)``, cast to
        ``uint8`` and reversed along the last dim (presumably channels-last RGB -> BGR).
    """
    if g_seed is not None:
        self.rng.manual_seed(g_seed)

    logits_BlV = self.forward(label_B_or_BLT, x_BLC_wo_prefix_lq, scale_schedule)

    img = self.logits_to_img(logits_BlV_all=logits_BlV,
                             vae=vae,
                             scale_schedule=scale_schedule,
                             top_k=top_k,
                             top_p=top_p,
                             g_seed=g_seed)

    img = (img + 1) / 2
    # Clamp before the uint8 cast: converting out-of-range floats to uint8 is not well defined.
    img = img.permute(0, 2, 3, 1).mul_(255).clamp_(0, 255).to(torch.uint8).flip(dims=(3,))

    return None, None, img
|
|
| |
@for_visualize
def vis_key_params(self, ep):
    """Visualization hook wrapped by ``for_visualize``; this model logs nothing (returns None)."""
    return None
| |
def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False):
    """Adapt ``state_dict`` to this model, then delegate to ``nn.Module.load_state_dict``.

    * Every entry whose key contains ``'cfg_uncond'`` is resized along dim 0 to the model's
      current ``cfg_uncond`` length: a shorter checkpoint tensor is moved to the model's
      device/dtype and extended with the model's own trailing rows; a longer one is truncated.
    * The cached buffers ``lvl_1L``, ``attn_bias_for_masking``, ``Infinity_visible_kvlen`` and
      ``Infinity_invisible_qlen`` are removed from the checkpoint and, where the model has such
      an attribute, replaced by the model's current value.

    ``state_dict`` is modified in place.
    """
    uncond_keys = [key for key in state_dict if 'cfg_uncond' in key]
    for key in uncond_keys:
        ckpt_tensor = state_dict[key]
        model_tensor = self.cfg_uncond.data
        shared_len = min(ckpt_tensor.shape[0], model_tensor.shape[0])
        ckpt_is_shorter_or_equal = shared_len == ckpt_tensor.shape[0]
        if not ckpt_is_shorter_or_equal:
            state_dict[key] = ckpt_tensor[:shared_len]
            continue
        moved = ckpt_tensor.to(device=model_tensor.device, dtype=model_tensor.dtype)
        state_dict[key] = torch.cat((moved, model_tensor[shared_len:]))

    cached_buffers = ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen')
    for name in cached_buffers:
        state_dict.pop(name, None)
        if not hasattr(self, name):
            continue
        state_dict[name] = getattr(self, name)

    return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
| |
def special_init(
    self,
    aln_init: float,
    aln_gamma_init: float,
    scale_head: float,
    scale_proj: int,
):
    """Rescale selected weights in place.

    :param aln_init: multiplier for the head AdaLN's last ``ada_lin`` weight (its bias is zeroed)
        and for the part of each block's AdaLN parameters after the first ``2*C`` rows / 2 slots.
    :param aln_gamma_init: multiplier for the first ``2*self.C`` rows of each block's
        ``ada_lin[-1]`` weight, or the first two slots (dim 2) of ``ada_gss``
        (presumably the gamma terms — confirm).
    :param scale_head: if >= 0, multiply the final Linear of ``self.head`` by it and zero its bias.
    :param scale_proj: if 1, multiply each block's attention output projection(s) and
        ``ffn.fc2`` weights by ``1/sqrt(2*depth)``.
    """
    if isinstance(self.head_nm, AdaLNBeforeHead):
        self.head_nm.ada_lin[-1].weight.data.mul_(aln_init)
        if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
            self.head_nm.ada_lin[-1].bias.data.zero_()

    if scale_head >= 0:
        if isinstance(self.head, nn.Linear):
            self.head.weight.data.mul_(scale_head)
            self.head.bias.data.zero_()
        elif isinstance(self.head, nn.Sequential):
            # Only the last layer of a Sequential head is rescaled.
            self.head[-1].weight.data.mul_(scale_head)
            self.head[-1].bias.data.zero_()

    depth = len(self.unregistered_blocks)
    for block_idx, sab in enumerate(self.unregistered_blocks):
        sab: Union[SelfAttnBlock, CrossAttnBlock]

        # NOTE(review): the rescaling below only runs when scale_proj == 1, so the
        # `2*(1 + block_idx)` alternative of this expression is never used — confirm intent.
        scale = 1 / math.sqrt(2*depth if scale_proj == 1 else 2*(1 + block_idx))
        if scale_proj == 1:
            if self.t2i:
                # Cross-attention blocks: both self- and cross-attention output projections.
                sab.sa.proj.weight.data.mul_(scale)
                sab.ca.proj.weight.data.mul_(scale)
            else:
                sab.attn.proj.weight.data.mul_(scale)
            sab.ffn.fc2.weight.data.mul_(scale)

        # Per-block AdaLN: leading part scaled by aln_gamma_init, the remainder by aln_init.
        if hasattr(sab, 'ada_lin'):
            lin = sab.ada_lin[-1]
            lin.weight.data[:2*self.C].mul_(aln_gamma_init)
            lin.weight.data[2*self.C:].mul_(aln_init)
            if hasattr(lin, 'bias') and lin.bias is not None:
                lin.bias.data.zero_()
        elif hasattr(sab, 'ada_gss'):
            sab.ada_gss.data[:, :, :2, :].mul_(aln_gamma_init)
            sab.ada_gss.data[:, :, 2:, :].mul_(aln_init)
| |
def extra_repr(self):
    """Return the extra text that ``nn.Module.__repr__`` shows: the drop-path rate."""
    rate = self.drop_path_rate
    return f'drop_path_rate={rate}'
| |
def get_layer_id_and_scale_exp(self, para_name: str):
    """Unsupported for this model: always raises ``NotImplementedError``."""
    raise NotImplementedError()
| |
| class AInfinity(nn.Module): |
def __init__(
    self, vae_local,
    text_channels=0, text_maxlen=0,
    selecting_idx=None,
    embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4.,
    drop_rate=0., drop_path_rate=0.,
    norm_eps=1e-6, rms_norm=False,
    shared_aln=False, head_aln=True,
    cond_drop_rate=0.1,
    rand_uncond=False,
    cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False,
    raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
    head_depth=1,
    top_p=0.0, top_k=0.0,
    customized_flash_attn=False, fused_mlp=False, fused_norm=False,
    block_chunks=1,
    checkpointing=None,
    pad_to_multiplier=0,
    use_flex_attn=False,
    batch_size=2,
    add_lvl_embeding_only_first_block=1,
    use_bit_label=1,
    rope2d_each_sa_layer=0,
    rope2d_normalized_by_hw=0,
    pn=None,
    train_h_div_w_list=None,
    video_frames=1,
    always_training_scales=20,
    apply_spatial_patchify = 0,
    inference_mode=False,
):
    """Build the transformer, including an extra ``lq_linear`` projection for the ``*_lq`` token stream.

    :param vae_local: VAE whose ``embed_dim`` (x4 when ``apply_spatial_patchify``) sets the token
        width ``d_vae``; with ``use_bit_label`` its ``quantizer.lfq.mask`` is kept as
        ``self.bit_mask``, otherwise its ``vocab_size`` sets ``self.V``.
    :param text_channels: width of the text features; 0 selects class-label conditioning.
    :param block_chunks: number of ``MultipleLayers`` chunks the blocks are grouped into;
        ``None``/``0`` are treated as 1. ``depth`` must be divisible by it.
    :param pad_to_multiplier: sequence padding multiple stored in ``self.pad_to_multiplier``
        (values < 1 become 1).
    :param rope2d_each_sa_layer: must be 1; any other value raises ``ValueError``.
    :param use_flex_attn: if true, FlexAttn objects are precompiled via ``compile_flex_attn``.
    """
    # NOTE: these attributes are assigned before nn.Module.__init__ (called further below).
    self.C = embed_dim
    self.inference_mode = inference_mode
    self.apply_spatial_patchify = apply_spatial_patchify
    if self.apply_spatial_patchify:
        self.d_vae = vae_local.embed_dim * 4
    else:
        self.d_vae = vae_local.embed_dim
    self.use_bit_label = use_bit_label
    self.codebook_dim = self.d_vae
    self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size
    self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None
    self.Ct5 = text_channels
    self.depth = depth
    self.num_heads = num_heads
    self.batch_size = batch_size
    self.mlp_ratio = mlp_ratio
    self.cond_drop_rate = cond_drop_rate
    self.norm_eps = norm_eps
    self.prog_si = -1
    self.pn = pn
    self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates
    self.video_frames = video_frames
    self.always_training_scales = always_training_scales

    assert add_lvl_embeding_only_first_block in [0,1]
    self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block
    assert rope2d_each_sa_layer in [0,1]
    self.rope2d_each_sa_layer = rope2d_each_sa_layer
    self.rope2d_normalized_by_hw = rope2d_normalized_by_hw
    print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, \
        self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}')
    head_up_method = ''
    word_patch_size = 1 if head_up_method in {'', 'no'} else 2
    if word_patch_size > 1:
        assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}'

    self.checkpointing = checkpointing
    self.pad_to_multiplier = max(1, pad_to_multiplier)

    # The customized kernel is detected by an 'Infinity*' argument name in flash_attn_func.
    customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames)
    self.customized_flash_attn = customized_flash_attn and customized_kernel_installed
    if customized_flash_attn and not customized_kernel_installed:
        import inspect, warnings
        file_path = inspect.getsourcefile(flash_attn_func)
        line_number = inspect.getsourcelines(flash_attn_func)[1]
        info = (
            f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<\n'
            f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n'
            f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n'
        )
        warnings.warn(info, ImportWarning)
        print(info, flush=True)

    self.raw_scale_schedule = raw_scale_schedule
    self.first_l = 1

    # top_k in (0, 1) is a fraction of the vocabulary; non-positive or >= V disables it.
    self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k))
    if self.top_p < 1e-5: self.top_p = 0
    if self.top_k >= self.V or self.top_k <= 0: self.top_k = 0

    # All ranks must agree on whether the fused flash op is installed.
    t = torch.zeros(dist.get_world_size(), device=dist.get_device())
    t[dist.get_rank()] = float(flash_fused_op_installed)
    dist.barrier()
    dist.allreduce(t)
    assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}'

    super().__init__()
    self.rng = torch.Generator(device=dist.get_device())
    self.maybe_record_function = nullcontext
    self.text_maxlen = text_maxlen
    self.t2i = text_channels != 0

    init_std = math.sqrt(1 / self.C / 3)
    self.norm0_cond = nn.Identity()
    if self.t2i:
        self.selecting_idx = None
        self.num_classes = 0
        self.D = self.C

        # Unconditional text embedding, initialized from a fixed-seed CPU generator.
        cfg_uncond = torch.empty(self.text_maxlen, self.Ct5)
        rng = torch.Generator(device='cpu')
        rng.manual_seed(0)
        torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng)
        cfg_uncond /= self.Ct5 ** 0.5
        if rand_uncond:
            self.register_buffer('cfg_uncond', cfg_uncond)
        else:
            self.cfg_uncond = nn.Parameter(cfg_uncond)

        self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps)
        self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D)
        self.text_proj_for_ca = nn.Sequential(
            nn.Linear(self.Ct5, self.D),
            nn.GELU(approximate='tanh'),
            nn.Linear(self.D, self.D),
        )
    else:
        if selecting_idx is None:
            num_classes = 1000
            print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======')
            selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device())
        self.selecting_idx = selecting_idx
        self.num_classes = selecting_idx.shape[-1]
        self.D = self.C
        self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
        nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)

    self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
    nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
    if self.rope2d_each_sa_layer:
        rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw)
        self.rope2d_freqs_grid = rope2d_freqs_grid
    else:
        raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented')
    self.lvl_embed = nn.Embedding(15, self.C)
    nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

    norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps)
    self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity()
    self.word_embed = nn.Linear(self.d_vae, self.C)

    self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()

    if fused_norm:
        fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm
        if fused_norm_func is not None:
            # NOTE(review): these tensors are never used, but torch.randn advances the global RNG;
            # removing them would shift the RNG stream seen by everything initialized afterwards.
            B = 2
            x = torch.randn(B, 1, self.C).requires_grad_(True)
            scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
            shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
            del B, x, scale, shift
    else:
        fused_norm_func = None

    self.use_flex_attn = use_flex_attn
    self.attn_fn_compile_dict = {}
    if self.use_flex_attn:
        self.attn_fn_compile_dict = self.compile_flex_attn()

    self.drop_path_rate = drop_path_rate
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    self.unregistered_blocks = []
    for block_idx in range(depth):
        # NOTE(review): blocks receive the raw pad_to_multiplier, not the clamped self.pad_to_multiplier.
        block = (CrossAttnBlock if self.t2i else SelfAttnBlock)(
            embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer,
            num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn,
            swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func,
            checkpointing_sa_only=self.checkpointing == 'self-attn',
            use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
        )
        self.unregistered_blocks.append(block)

    V = self.V
    if head_aln:
        self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func)
        self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
    else:
        self.head_nm = MultiInpIdentity()
        self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))

    # Use the normalized chunk count everywhere: the raw block_chunks may be None or 0.
    self.num_block_chunks = block_chunks or 1
    self.num_blocks_in_a_chunk = depth // self.num_block_chunks
    print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}")
    assert self.num_blocks_in_a_chunk * self.num_block_chunks == depth
    if self.num_block_chunks == 1:
        self.blocks = nn.ModuleList(self.unregistered_blocks)
    else:
        self.block_chunks = nn.ModuleList()
        for i in range(self.num_block_chunks):
            self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk))

    self.lq_linear = nn.Linear(self.C, self.C)
    print(
        f'\n[constructor]  ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n'
        f'    [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n'
        f'    [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
        end='\n\n', flush=True
    )
| |
|
|
def compile_flex_attn(self):
    """Precompile one FlexAttn object per scale-schedule prefix this model will run.

    For each aspect ratio in ``self.train_h_div_w_list`` the nearest template's scale list is
    taken from ``dynamic_resolution_h_w``. In inference mode every prefix length is compiled with
    the kv-cache mask and auto padding; in training mode only the first
    ``always_training_scales`` scales are compiled with the ``'var'`` mask. Temporal extents are
    capped at ``video_frames // 4 + 1``; when ``video_frames > 1`` a single-frame (t = 1)
    variant of each schedule is compiled too.

    :return: dict mapping ``tuple`` of ``(t, h, w)`` triples -> FlexAttn instance.
    """
    def _aligned_len(schedule):
        # Token count rounded up to a multiple of self.pad_to_multiplier.
        total = sum(pt * ph * pw for pt, ph, pw in schedule)
        leftover = total % self.pad_to_multiplier
        if leftover == 0:
            return total
        return total + (self.pad_to_multiplier - leftover)

    compiled = {}
    for h_div_w in self.train_h_div_w_list:
        nearest = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))]
        full_schedule = dynamic_resolution_h_w[nearest][self.pn]['scales']
        if not self.inference_mode:
            mask_type, auto_padding = 'var', False
            prefix_lengths = [min(self.always_training_scales, len(full_schedule))]
        else:
            prefix_lengths = list(range(1, 1 + len(full_schedule)))
            mask_type, auto_padding = "infinity_infer_mask_with_kv_cache", True

        for scales_num in prefix_lengths:
            print(f'====== apply flex attn hdivw: {h_div_w} scales: {scales_num} ======')
            capped = tuple((min(t, self.video_frames // 4 + 1), h, w) for t, h, w in full_schedule[:scales_num])
            compiled[capped] = FlexAttn(
                block_scales=capped,
                mask_type=mask_type,
                B=self.batch_size,
                H=self.num_heads,
                L=_aligned_len(capped),
                auto_padding=auto_padding,
            )

            if self.video_frames > 1:
                single_frame = tuple((1, h, w) for _t, h, w in capped)
                compiled[single_frame] = FlexAttn(
                    block_scales=single_frame,
                    mask_type=mask_type,
                    B=self.batch_size,
                    H=self.num_heads,
                    L=_aligned_len(single_frame),
                )
    return compiled
| |
def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
    """Map hidden states to logits in float32 with CUDA autocast disabled.

    :param h: hidden states, shaped (B, L, C); cast to float32.
    :param cond_BD: condition, shaped (B, D); cast to float32 and passed to ``self.head_nm``.
    :return: ``self.head(self.head_nm(h, cond_BD))`` computed on the float32 inputs.
    """
    hidden32 = h.float()
    cond32 = cond_BD.float()
    with torch.amp.autocast('cuda', enabled=False):
        normed = self.head_nm(hidden32, cond32)
        logits = self.head(normed)
    return logits
|
|
def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0):
    """Add the level embedding of scale ``scale_ind`` to the first t*h*w rows of ``feature``, in place.

    :param feature: 3-D tensor (batch, seq, channels); ``seq`` must equal the scale's token
        count ``t*h*w`` plus ``need_to_pad`` (trailing rows are left untouched).
    :param scale_ind: index into ``scale_schedule`` and into ``self.lvl_embed``.
    :param scale_schedule: list of ``(t, h, w)`` per scale.
    :return: ``feature`` itself (modified in place).
    """
    batch, total_len, _channels = feature.shape
    t, hh, ww = scale_schedule[scale_ind]
    n_tokens = t * hh * ww
    assert n_tokens + need_to_pad == total_len
    level_ids = torch.full((batch, n_tokens), scale_ind, dtype=torch.int, device=feature.device)
    feature[:, :n_tokens] += self.lvl_embed(level_ids)
    return feature
| |
def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0):
    """Add each scale's level embedding to its token span of ``x_BLC``.

    The sequence is cut into consecutive spans of ``t*h*w`` tokens (one per entry of
    ``scale_schedule``), each span goes through ``self.add_lvl_embeding``, and the remaining
    ``need_to_pad`` rows are appended unchanged.

    Note: ``add_lvl_embeding`` works in place on slices of ``x_BLC``, so the input tensor is
    modified as well.

    :return: the concatenation of the processed spans and the padding rows (along dim 1).
    """
    pieces = []
    start = 0
    for level, thw in enumerate(scale_schedule):
        span_len = np.array(thw).prod()
        span = x_BLC[:, start:start + span_len]
        start += span_len
        pieces.append(self.add_lvl_embeding(span, level, scale_schedule))
    assert x_BLC.shape[1] == (start + need_to_pad), f'{x_BLC.shape[1]} != {start} + {need_to_pad}'
    pieces.append(x_BLC[:, start:])
    return torch.cat(pieces, dim=1)
|
|
def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
    cfg_infer=False, x_BLC_w_prefix_lq=None,
    **kwargs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Teacher-forced forward pass with an additional token stream ``x_BLC_w_prefix_lq``.

    label_B_or_BLT: unpacked as (kv_compact, lens, cu_seqlens_k, max_seqlen_k): packed text
        features, per-sample text lengths, and the varlen metadata handed to text_proj_for_sos
        and the cross-attention blocks.
    x_BLC_wo_prefix: input tokens of all scales without the start token.
    scale_schedule: list of (t, h, w) per scale; scale i contributes t*h*w tokens.
    cfg_infer: if True, return self.autoregressive_infer_cfg(...) instead; kwargs and
        x_BLC_w_prefix_lq are forwarded to it.
    x_BLC_w_prefix_lq: second token stream. It goes through the same norm0_ve/word_embed as the
        main tokens, then lq_linear, and is added element-wise to the main sequence, so after
        embedding it must match the main sequence length (start token included). Required when
        cfg_infer is False.

    Side effect: with probability self.cond_drop_rate per sample (Python's global `random`),
    that sample's rows of the caller's kv_compact are overwritten in place with self.cfg_uncond.

    :return: logits BLV, V is vocab_size; sliced back to the unpadded length l_end
    :raises ValueError: if cfg_infer is False and x_BLC_w_prefix_lq is None.
    """
    if cfg_infer:
        # x_BLC_w_prefix_lq is a named parameter here, so it never lands in **kwargs;
        # pass it on explicitly (autoregressive_infer_cfg requires it).
        return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, x_BLC_w_prefix_lq=x_BLC_w_prefix_lq, **kwargs)
    if x_BLC_w_prefix_lq is None:
        raise ValueError('x_BLC_w_prefix_lq is required when cfg_infer=False')

    x_BLC_wo_prefix = x_BLC_wo_prefix.float()
    x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.float()
    B = x_BLC_wo_prefix.shape[0]

    with torch.amp.autocast('cuda', enabled=False):
        kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT

        # Condition dropout: replace whole samples' text rows with the unconditional embedding.
        total = 0
        for le in lens:
            if random.random() < self.cond_drop_rate:
                kv_compact[total:total+le] = self.cfg_uncond[:le]
            total += le
        # Zero-valued term that keeps self.cfg_uncond in the autograd graph even when nothing was dropped.
        must_on_graph = self.cfg_uncond[0, 0] * 0
        kv_compact = self.text_norm(kv_compact).contiguous()
        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous()
        kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
        kv_compact[0, 0] += must_on_graph
        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k

        cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous()

        sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
        x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)
        x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq))
        # Merge the LQ stream before any padding so both tensors have the same length;
        # padding rows appended below therefore stay zero.
        x_BLC = x_BLC + self.lq_linear(x_BLC_lq)

        l_end = x_BLC.shape[1]
        need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end

        if self.customized_flash_attn:
            Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end]
            Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end]
            attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen)

        elif self.use_flex_attn:
            if need_to_pad:
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
            # NOTE(review): checks the channel (last) dim, not the padded sequence length.
            assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'
            attn_bias_or_two_vector = None
        else:
            # Dense additive mask: token i may attend to token j iff scale(i) >= scale(j).
            d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1)
            dT = d.transpose(1, 2)
            attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end)
            attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous()
            if need_to_pad:
                attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf)
                # Each padded query row gets one finite entry so it is not all -inf.
                attn_bias[0, 0, l_end:, 0] = 0
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
            attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device)

    if self.use_flex_attn:
        attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)]
    else:
        attn_fn = None

    checkpointing_full_block = self.checkpointing == 'full-block' and self.training
    if self.num_block_chunks == 1:
        for i, b in enumerate(self.blocks):
            if self.add_lvl_embeding_only_first_block and i == 0:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if checkpointing_full_block:
                x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
            else:
                x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
    else:
        for i, chunk in enumerate(self.block_chunks):
            if self.add_lvl_embeding_only_first_block and i == 0:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)

    return self.get_logits(x_BLC[:, :l_end], cond_BD)
|
|
| @torch.no_grad() |
| def autoregressive_infer_cfg( |
| self, |
| vae=None, |
| scale_schedule=None, |
| label_B_or_BLT=None, |
| B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None, |
| g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0, |
| returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False, |
| cfg_exp_k: float=0.0, cfg_insertion_layer=[-5], |
| vae_type=0, softmax_merge_topk=-1, ret_img=False, |
| trunk_scale=1000, |
| gt_leak=0, gt_ls_Bl=None, |
| inference_mode=False, |
| save_img_path=None, |
| sampling_per_bits=1, |
| x_BLC_w_prefix_lq=None, |
| |
| ): |
| if g_seed is None: rng = None |
| else: self.rng.manual_seed(g_seed); rng = self.rng |
| assert len(cfg_list) >= len(scale_schedule) |
| assert len(tau_list) >= len(scale_schedule) |
| |
| |
| x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.float() |
| x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq)) |
| x_BLC_lq = self.lq_linear(x_BLC_lq) |
| patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] |
| x_BLC_lq_list = list(torch.split(x_BLC_lq,patch_nums_per_level,dim=1)) |
| |
|
|
| |
| |
| if self.apply_spatial_patchify: |
| vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] |
| else: |
| vae_scale_schedule = scale_schedule |
| |
| kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT |
| if any(np.array(cfg_list) != 1): |
| bs = 2*B |
| if not negative_label_B_or_BLT: |
| kv_compact_un = kv_compact.clone() |
| total = 0 |
| for le in lens: |
| kv_compact_un[total:total+le] = (self.cfg_uncond)[:le] |
| total += le |
| kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) |
| cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0) |
| else: |
| kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT |
| kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) |
| cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0) |
| max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un) |
| else: |
| bs = B |
|
|
| kv_compact = self.text_norm(kv_compact) |
| sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) |
| kv_compact = self.text_proj_for_ca(kv_compact) |
| ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k |
| last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1) |
|
|
| with torch.amp.autocast('cuda', enabled=False): |
| cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous() |
| accu_BChw, cur_L, ret = None, 0, [] |
| idx_Bl_list, idx_Bld_list = [], [] |
|
|
| if inference_mode: |
| for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True) |
| else: |
| assert self.num_block_chunks > 1 |
| for block_chunk_ in self.block_chunks: |
| for module in block_chunk_.module.module: |
| (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True) |
| |
| abs_cfg_insertion_layers = [] |
| add_cfg_on_logits, add_cfg_on_probs = False, False |
| leng = len(self.unregistered_blocks) |
| for item in cfg_insertion_layer: |
| if item == 0: |
| add_cfg_on_logits = True |
| elif item == 1: |
| add_cfg_on_probs = True |
| elif item < 0: |
| assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={self.num_block_chunks}' |
| abs_cfg_insertion_layers.append(leng+item) |
| else: |
| raise ValueError(f'cfg_insertion_layer: {item} is not valid') |
| |
| num_stages_minus_1 = len(scale_schedule)-1 |
| summed_codes = 0 |
| |
| |
| |
| |
| for si, pn in enumerate(scale_schedule): |
| cfg = cfg_list[si] |
| if si >= trunk_scale: |
| break |
| cur_L += np.array(pn).prod() |
| |
| |
| last_stage = last_stage + x_BLC_lq_list[si] |
|
|
| need_to_pad = 0 |
| attn_fn = None |
| if self.use_flex_attn: |
| |
| |
| |
| attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None) |
|
|
| |
| layer_idx = 0 |
| for block_idx, b in enumerate(self.block_chunks): |
| |
| if self.add_lvl_embeding_only_first_block and block_idx == 0: |
| last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) |
| if not self.add_lvl_embeding_only_first_block: |
| last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) |
| for m in b.module: |
| last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si) |
| if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers): |
| |
| last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:] |
| last_stage = torch.cat((last_stage, last_stage), 0) |
| layer_idx += 1 |
| |
| if (cfg != 1) and add_cfg_on_logits: |
| |
| logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si]) |
| logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:] |
| else: |
| logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si]) |
| |
| if self.use_bit_label: |
| tmp_bs, tmp_seq_len = logits_BlV.shape[:2] |
| logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) |
| idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] |
| idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) |
| else: |
| idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] |
| if vae_type != 0: |
| assert returns_vemb |
| if si < gt_leak: |
| idx_Bld = gt_ls_Bl[si] |
| else: |
| assert pn[0] == 1 |
| idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) |
| if self.apply_spatial_patchify: |
| idx_Bld = idx_Bld.permute(0,3,1,2) |
| idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) |
| idx_Bld = idx_Bld.permute(0,2,3,1) |
| idx_Bld = idx_Bld.unsqueeze(1) |
|
|
| idx_Bld_list.append(idx_Bld) |
| codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') |
| if si != num_stages_minus_1: |
| summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) |
| last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_down) |
| last_stage = last_stage.squeeze(-3) |
| if self.apply_spatial_patchify: |
| last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2) |
| last_stage = last_stage.reshape(*last_stage.shape[:2], -1) |
| last_stage = torch.permute(last_stage, [0,2,1]) |
| else: |
| summed_codes += codes |
| else: |
| if si < gt_leak: |
| idx_Bl = gt_ls_Bl[si] |
| h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float() |
|
|
| |
| h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2]) |
| ret.append(h_BChw if returns_vemb != 0 else idx_Bl) |
| idx_Bl_list.append(idx_Bl) |
| if si != num_stages_minus_1: |
| accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule) |
| |
| if si != num_stages_minus_1: |
| last_stage = self.word_embed(self.norm0_ve(last_stage)) |
| last_stage = last_stage.repeat(bs//B, 1, 1) |
|
|
| if inference_mode: |
| for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) |
| else: |
| assert self.num_block_chunks > 1 |
| for block_chunk_ in self.block_chunks: |
| for module in block_chunk_.module.module: |
| (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) |
|
|
| if not ret_img: |
| return ret, idx_Bl_list, [] |
| |
| if vae_type != 0: |
| img = vae.decode(summed_codes.squeeze(-3)) |
| else: |
| img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True) |
|
|
| img = (img + 1) / 2 |
| img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) |
| return ret, idx_Bl_list, img |
| |
def logits_to_img(self, logits_BlV_all, vae, scale_schedule, top_k=900, top_p=0.97, g_seed=1):
    """
    Decode an image from concatenated per-scale bit-label logits using GumbelArgmax.

    The logits of every scale are split off ``logits_BlV_all`` (lengths t*h*w from
    ``scale_schedule``). Each bit is picked with ``GumbelArgmax(logits, 0.5)``; the
    result is multiplied by a [0, 1] selector over the last axis and summed, i.e. the
    second column is kept (assumes GumbelArgmax keeps the trailing dim of size 2 --
    TODO confirm). Bits are turned into codes with
    ``vae.quantizer.lfq.indices_to_codes(..., label_type='bit_label')``. Codes are
    upsampled to the last VAE scale, summed, and decoded with ``vae.decode``.

    :param logits_BlV_all: logits of all scales concatenated along dim 1
    :param vae: VAE providing ``quantizer`` and ``decode``
    :param scale_schedule: list of (t, h, w) per scale; t must be 1
    :param top_k: unused on this (GumbelArgmax) path; kept for API compatibility
    :param top_p: unused on this (GumbelArgmax) path; kept for API compatibility
    :param g_seed: if not None, ``self.rng`` is re-seeded with it
    :return: whatever ``vae.decode`` returns
    :raises NotImplementedError: if ``self.use_bit_label`` is false (only bit labels
        can be turned into codes here)
    """
    patch_nums_per_level = [pn[0] * pn[1] * pn[2] for pn in scale_schedule]
    logits_BlV_list = list(torch.split(logits_BlV_all, patch_nums_per_level, dim=1))
    B = logits_BlV_all.shape[0]

    # Seeding is kept as a side effect for parity with the sampling variant.
    if g_seed is not None:
        self.rng.manual_seed(g_seed)

    if not self.use_bit_label:
        # The previous non-bit-label branch sampled into `idx_Bl` but then read the
        # unbound `idx_Bld`, so it could never work; fail loudly instead.
        raise NotImplementedError(
            f'logits_to_img only supports bit labels (use_bit_label=True); got use_bit_label={self.use_bit_label!r}'
        )

    if self.apply_spatial_patchify:
        vae_scale_schedule = [(pt, 2 * ph, 2 * pw) for pt, ph, pw in scale_schedule]
    else:
        vae_scale_schedule = scale_schedule

    summed_codes = 0
    num_stages_minus_1 = len(scale_schedule) - 1

    for si, logits_BlV in enumerate(logits_BlV_list):
        pn = scale_schedule[si]
        tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
        logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)

        idx_Bld = GumbelArgmax(logits_BlV, 0.5)
        # Keep only the second column (multiply by [0, 1] and sum); written this way
        # so gradients flow exactly as in the original formulation.
        bit_selector = torch.zeros_like(idx_Bld)
        bit_selector[:, :, 1:] = 1
        idx_Bld = (idx_Bld * bit_selector).sum(dim=-1)
        idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)

        assert pn[0] == 1
        idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1)
        if self.apply_spatial_patchify:
            # undo the 2x2 spatial patchify
            idx_Bld = idx_Bld.permute(0, 3, 1, 2)
            idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2)
            idx_Bld = idx_Bld.permute(0, 2, 3, 1)
        idx_Bld = idx_Bld.unsqueeze(1)

        codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label')
        if si != num_stages_minus_1:
            summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
        else:
            summed_codes += codes

    img = vae.decode(summed_codes.squeeze(-3))
    return img
| |
def logits_to_img_gumble(self, logits_BlV_all, vae, scale_schedule, top_k=900, top_p=0.97, g_seed=1):
    """
    Decode an image from concatenated per-scale bit-label logits by top-k/top-p sampling.

    The logits of every scale are split off ``logits_BlV_all`` (lengths t*h*w from
    ``scale_schedule``). Each bit is sampled with
    ``sample_with_top_k_top_p_also_modifying_logits_``. Bits are turned into codes with
    ``vae.quantizer.lfq.indices_to_codes(..., label_type='bit_label')``. Codes are
    upsampled to the last VAE scale, summed, and decoded with ``vae.decode``.

    :param logits_BlV_all: logits of all scales concatenated along dim 1
    :param vae: VAE providing ``quantizer`` and ``decode``
    :param scale_schedule: list of (t, h, w) per scale; t must be 1
    :param top_k: top-k for sampling; a falsy value falls back to ``self.top_k``
    :param top_p: top-p for sampling; a falsy value falls back to ``self.top_p``
    :param g_seed: if not None, ``self.rng`` is re-seeded and used for sampling
    :return: whatever ``vae.decode`` returns
    :raises NotImplementedError: if ``self.use_bit_label`` is false (only bit labels
        can be turned into codes here)
    """
    patch_nums_per_level = [pn[0] * pn[1] * pn[2] for pn in scale_schedule]
    logits_BlV_list = list(torch.split(logits_BlV_all, patch_nums_per_level, dim=1))
    B = logits_BlV_all.shape[0]

    if g_seed is None:
        rng = None
    else:
        self.rng.manual_seed(g_seed)
        rng = self.rng

    if not self.use_bit_label:
        # The previous non-bit-label branch sampled into `idx_Bl` but then read the
        # unbound `idx_Bld`, so it could never work; fail loudly instead.
        raise NotImplementedError(
            f'logits_to_img_gumble only supports bit labels (use_bit_label=True); got use_bit_label={self.use_bit_label!r}'
        )

    if self.apply_spatial_patchify:
        vae_scale_schedule = [(pt, 2 * ph, 2 * pw) for pt, ph, pw in scale_schedule]
    else:
        vae_scale_schedule = scale_schedule

    summed_codes = 0
    num_stages_minus_1 = len(scale_schedule) - 1

    for si, logits_BlV in enumerate(logits_BlV_list):
        pn = scale_schedule[si]
        tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
        logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
        idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(
            logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1
        )[:, :, 0]
        idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)

        assert pn[0] == 1
        idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1)
        if self.apply_spatial_patchify:
            # undo the 2x2 spatial patchify
            idx_Bld = idx_Bld.permute(0, 3, 1, 2)
            idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2)
            idx_Bld = idx_Bld.permute(0, 2, 3, 1)
        idx_Bld = idx_Bld.unsqueeze(1)

        codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label')
        if si != num_stages_minus_1:
            summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
        else:
            summed_codes += codes

    img = vae.decode(summed_codes.squeeze(-3))
    return img
| |
@for_visualize
def vis_key_params(self, ep):
    """Visualization hook for epoch ``ep``; intentionally does nothing for this model.

    Wrapped by ``for_visualize`` (imported from ``infinity.utils.dist``).
    """
    return None
| |
def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False):
    """Load a checkpoint after reconciling it with this model.

    * Every entry whose key contains ``'cfg_uncond'`` is resized along dim 0 to fit
      ``self.cfg_uncond``. A longer checkpoint tensor is truncated. A shorter (or
      equal) one is cast to the model's device/dtype and topped up with the model's
      own trailing rows.
    * The buffers ``lvl_1L``, ``attn_bias_for_masking``, ``Infinity_visible_kvlen`` and
      ``Infinity_invisible_qlen`` are never taken from the checkpoint. They are removed
      and, when this model has such an attribute, replaced by the model's own value.

    ``state_dict`` is modified in place before being passed to the parent
    ``load_state_dict``. Note ``strict`` defaults to False here.
    """
    cfg_keys = [name for name in state_dict if 'cfg_uncond' in name]
    for name in cfg_keys:
        ckpt_table = state_dict[name]
        own_table = self.cfg_uncond.data
        n_common = min(ckpt_table.shape[0], own_table.shape[0])
        if n_common != ckpt_table.shape[0]:
            # checkpoint has more rows than the model: keep only the leading ones
            state_dict[name] = ckpt_table[:n_common]
        else:
            # checkpoint has no more rows than the model: top up with the model's own rows
            ckpt_table = ckpt_table.to(device=own_table.device, dtype=own_table.dtype)
            state_dict[name] = torch.cat((ckpt_table, own_table[n_common:]))

    derived_buffers = ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen')
    for buf_name in derived_buffers:
        state_dict.pop(buf_name, None)
        if not hasattr(self, buf_name):
            continue
        state_dict[buf_name] = getattr(self, buf_name)

    return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
| |
def special_init(
    self,
    aln_init: float,
    aln_gamma_init: float,
    scale_head: float,
    scale_proj: int,
):
    """Rescale selected weights in place after default initialization.

    :param aln_init: multiplier for the last ``ada_lin`` layer of the head AdaLN, for
        rows ``2*C:`` of each block's ``ada_lin[-1]``, and for ``ada_gss[:, :, 2:]``
    :param aln_gamma_init: multiplier for rows ``:2*C`` of each block's ``ada_lin[-1]``
        and for ``ada_gss[:, :, :2]``
    :param scale_head: if >= 0, multiplier for the output head's final Linear weight
        (whose bias is zeroed)
    :param scale_proj: if 1, the attention (and cross-attention when ``self.t2i``)
        ``proj`` weights and ``ffn.fc2`` weights of every block are multiplied by
        ``1/sqrt(2*depth)``; any other value leaves them untouched
    """
    # AdaLN in front of the output head
    if isinstance(self.head_nm, AdaLNBeforeHead):
        head_ada_out = self.head_nm.ada_lin[-1]
        head_ada_out.weight.data.mul_(aln_init)
        if getattr(head_ada_out, 'bias', None) is not None:
            head_ada_out.bias.data.zero_()

    # output head: either a single Linear or a Sequential ending in a Linear
    if scale_head >= 0:
        if isinstance(self.head, nn.Linear):
            head_out = self.head
        elif isinstance(self.head, nn.Sequential):
            head_out = self.head[-1]
        else:
            head_out = None
        if head_out is not None:
            head_out.weight.data.mul_(scale_head)
            head_out.bias.data.zero_()

    # transformer blocks (SelfAttnBlock or CrossAttnBlock)
    n_blocks = len(self.unregistered_blocks)
    for sab in self.unregistered_blocks:
        if scale_proj == 1:
            proj_scale = 1 / math.sqrt(2 * n_blocks)
            if self.t2i:
                scaled_weights = [sab.sa.proj.weight, sab.ca.proj.weight]
            else:
                scaled_weights = [sab.attn.proj.weight]
            scaled_weights.append(sab.ffn.fc2.weight)
            for weight in scaled_weights:
                weight.data.mul_(proj_scale)

        if hasattr(sab, 'ada_lin'):
            ada_out = sab.ada_lin[-1]
            gamma_rows = 2 * self.C
            ada_out.weight.data[:gamma_rows].mul_(aln_gamma_init)
            ada_out.weight.data[gamma_rows:].mul_(aln_init)
            if getattr(ada_out, 'bias', None) is not None:
                ada_out.bias.data.zero_()
        elif hasattr(sab, 'ada_gss'):
            gss = sab.ada_gss.data
            gss[:, :, :2, :].mul_(aln_gamma_init)
            gss[:, :, 2:, :].mul_(aln_init)
| |
def extra_repr(self):
    """Return the extra text shown in this module's ``repr`` (the drop-path rate)."""
    return 'drop_path_rate={}'.format(self.drop_path_rate)
| |
def get_layer_id_and_scale_exp(self, para_name: str):
    """Layer-id / scale-exponent lookup for ``para_name``; not supported by this model.

    :raises NotImplementedError: always
    """
    raise NotImplementedError()
| |
| class FAInfinity(nn.Module): |
def __init__(
    self, vae_local,
    text_channels=0, text_maxlen=0,
    selecting_idx=None,
    embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4.,
    drop_rate=0., drop_path_rate=0.,
    norm_eps=1e-6, rms_norm=False,
    shared_aln=False, head_aln=True,
    cond_drop_rate=0.1,
    rand_uncond=False,
    cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False,
    raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
    head_depth=1,
    top_p=0.0, top_k=0.0,
    customized_flash_attn=False, fused_mlp=False, fused_norm=False,
    block_chunks=1,
    checkpointing=None,
    pad_to_multiplier=0,
    use_flex_attn=False,
    batch_size=2,
    add_lvl_embeding_only_first_block=1,
    use_bit_label=1,
    rope2d_each_sa_layer=0,
    rope2d_normalized_by_hw=0,
    pn=None,
    train_h_div_w_list=None,
    video_frames=1,
    always_training_scales=20,
    apply_spatial_patchify = 0,
    inference_mode=False,
):
    """
    Build the model: text conditioning, embeddings, transformer blocks, head, and the
    extra time-embedding / LQ projection layers.

    Selected parameters (others are forwarded to the blocks or stored as attributes):
    :param vae_local: VAE whose ``embed_dim``, ``vocab_size`` and ``quantizer.lfq.mask`` are read
    :param text_channels: text-feature width; non-zero enables the text-to-image (``t2i``) path
    :param block_chunks: number of ``MultipleLayers`` chunks; ``None``/0 is treated as 1.
        ``depth`` must be divisible by it.
    :param pad_to_multiplier: sequence-length padding multiple (clamped to at least 1)
    :param rope2d_each_sa_layer: must be 1; 0 raises ``ValueError``
    :param use_flex_attn: if true, FlexAttn objects are prebuilt via ``compile_flex_attn``
    """
    # Plain (non-Parameter, non-Module) attributes may be set before nn.Module.__init__.
    self.C = embed_dim
    self.inference_mode = inference_mode
    self.apply_spatial_patchify = apply_spatial_patchify
    if self.apply_spatial_patchify:
        self.d_vae = vae_local.embed_dim * 4
    else:
        self.d_vae = vae_local.embed_dim
    self.use_bit_label = use_bit_label
    self.codebook_dim = self.d_vae
    self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size
    self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None
    self.Ct5 = text_channels
    self.depth = depth
    self.num_heads = num_heads
    self.batch_size = batch_size
    self.mlp_ratio = mlp_ratio
    self.cond_drop_rate = cond_drop_rate
    self.norm_eps = norm_eps
    self.prog_si = -1
    self.pn = pn
    self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates
    self.video_frames = video_frames
    self.always_training_scales = always_training_scales

    assert add_lvl_embeding_only_first_block in [0,1]
    self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block
    assert rope2d_each_sa_layer in [0,1]
    self.rope2d_each_sa_layer = rope2d_each_sa_layer
    self.rope2d_normalized_by_hw = rope2d_normalized_by_hw
    # implicit concatenation instead of a backslash continuation inside the f-string,
    # which would have embedded the next line's indentation in the message
    print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, '
          f'self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}')
    head_up_method = ''
    word_patch_size = 1 if head_up_method in {'', 'no'} else 2
    if word_patch_size > 1:
        assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}'

    self.checkpointing = checkpointing
    self.pad_to_multiplier = max(1, pad_to_multiplier)

    # The customized flash-attn kernel is detected by an 'Infinity' argument name.
    customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames)
    self.customized_flash_attn = customized_flash_attn and customized_kernel_installed
    if customized_flash_attn and not customized_kernel_installed:
        import inspect, warnings
        file_path = inspect.getsourcefile(flash_attn_func)
        line_number = inspect.getsourcelines(flash_attn_func)[1]
        info = (
            f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<\n'
            f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n'
            f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n'
        )
        warnings.warn(info, ImportWarning)
        print(info, flush=True)

    self.raw_scale_schedule = raw_scale_schedule
    self.first_l = 1

    # top_k in (0, 1) is interpreted as a fraction of the vocabulary size
    self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k))
    if self.top_p < 1e-5: self.top_p = 0
    if self.top_k >= self.V or self.top_k <= 0: self.top_k = 0

    # all ranks must agree on whether the fused flash op is installed
    t = torch.zeros(dist.get_world_size(), device=dist.get_device())
    t[dist.get_rank()] = float(flash_fused_op_installed)
    dist.barrier()
    dist.allreduce(t)
    assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}'

    super().__init__()
    self.rng = torch.Generator(device=dist.get_device())
    self.maybe_record_function = nullcontext
    self.text_maxlen = text_maxlen
    self.t2i = text_channels != 0

    init_std = math.sqrt(1 / self.C / 3)
    self.norm0_cond = nn.Identity()
    if self.t2i:
        self.selecting_idx = None
        self.num_classes = 0
        self.D = self.C

        # learned (or fixed, if rand_uncond) "empty prompt" features used for CFG dropout
        cfg_uncond = torch.empty(self.text_maxlen, self.Ct5)
        rng = torch.Generator(device='cpu')
        rng.manual_seed(0)
        torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng)
        cfg_uncond /= self.Ct5 ** 0.5
        if rand_uncond:
            self.register_buffer('cfg_uncond', cfg_uncond)
        else:
            self.cfg_uncond = nn.Parameter(cfg_uncond)

        self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps)
        self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D)
        self.text_proj_for_ca = nn.Sequential(
            nn.Linear(self.Ct5, self.D),
            nn.GELU(approximate='tanh'),
            nn.Linear(self.D, self.D),
        )
    else:
        if selecting_idx is None:
            num_classes = 1000
            print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======')
            selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device())
        self.selecting_idx = selecting_idx
        self.num_classes = selecting_idx.shape[-1]
        self.D = self.C
        self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
        nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)

    self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
    nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
    if self.rope2d_each_sa_layer:
        rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw)
        self.rope2d_freqs_grid = rope2d_freqs_grid
    else:
        raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented')
    self.lvl_embed = nn.Embedding(15, self.C)
    nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

    norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps)
    self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity()
    self.word_embed = nn.Linear(self.d_vae, self.C)

    self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()

    if fused_norm:
        fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm
        if fused_norm_func is not None:
            # These dummy tensors are never used, but they draw from the global RNG;
            # they are kept so that later parameter initialization stays reproducible.
            B = 2
            x = torch.randn(B, 1, self.C).requires_grad_(True)
            scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
            shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
            del B, x, scale, shift
    else:
        fused_norm_func = None

    self.use_flex_attn = use_flex_attn
    self.attn_fn_compile_dict = {}
    self.batch_size = batch_size
    if self.use_flex_attn:
        self.attn_fn_compile_dict = self.compile_flex_attn()

    self.drop_path_rate = drop_path_rate
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    self.unregistered_blocks = []
    for block_idx in range(depth):
        block = (CrossAttnBlock if self.t2i else SelfAttnBlock)(
            embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer,
            num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn,
            swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func,
            checkpointing_sa_only=self.checkpointing == 'self-attn',
            use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
        )
        self.unregistered_blocks.append(block)

    V = self.V
    if head_aln:
        self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func)
        self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
    else:
        self.head_nm = MultiInpIdentity()
        self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))

    # Use the normalized chunk count everywhere: `block_chunks` may be None/0, which
    # previously passed the `or 1` here but then crashed in `depth // block_chunks`.
    self.num_block_chunks = block_chunks or 1
    self.num_blocks_in_a_chunk = depth // self.num_block_chunks
    print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}")
    assert self.num_blocks_in_a_chunk * self.num_block_chunks == depth
    if self.num_block_chunks == 1:
        self.blocks = nn.ModuleList(self.unregistered_blocks)
    else:
        self.block_chunks = nn.ModuleList()
        for i in range(self.num_block_chunks):
            self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk))

    self.time_embed = nn.Sequential(
        nn.Linear(self.C//4, self.C),
        nn.SiLU(),
        nn.Linear(self.C, self.C),
    )
    self.lq_linear = nn.Linear(self.C, self.C)

    print(
        f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n'
        f' [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n'
        f' [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
        end='\n\n', flush=True
    )
| |
|
|
def compile_flex_attn(self):
    """Prebuild FlexAttn objects keyed by scale schedule.

    For every aspect ratio in ``self.train_h_div_w_list``, the nearest entry of
    ``h_div_w_templates`` selects a full schedule from ``dynamic_resolution_h_w[...][self.pn]``.
    * Inference mode: one FlexAttn per schedule prefix length 1..N, with the
      ``"infinity_infer_mask_with_kv_cache"`` mask and ``auto_padding=True``.
    * Training mode: one FlexAttn for the prefix of length
      ``min(self.always_training_scales, N)``, with the ``'var'`` mask and ``auto_padding=False``.
    The temporal extent of each scale is capped at ``video_frames // 4 + 1``. When
    ``self.video_frames > 1``, an image-only (t=1) copy of the last built schedule is
    added per aspect ratio; it is built without passing ``auto_padding``.
    The sequence length is rounded up to a multiple of ``self.pad_to_multiplier``.

    :return: dict mapping ``tuple`` of ``(t, h, w)`` tuples to FlexAttn instances
    """
    multiplier = self.pad_to_multiplier
    max_t = self.video_frames // 4 + 1
    compiled = {}

    def _build(schedule, mask_kind, **extra_kwargs):
        # Create and register one FlexAttn for `schedule`.
        key = tuple(schedule)
        seq_len = sum(pt * ph * pw for pt, ph, pw in key)
        aligned_len = -(-seq_len // multiplier) * multiplier  # ceil to a multiple
        compiled[key] = FlexAttn(
            block_scales=key,
            mask_type=mask_kind,
            B=self.batch_size,
            H=self.num_heads,
            L=aligned_len,
            **extra_kwargs,
        )

    for h_div_w in self.train_h_div_w_list:
        nearest = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))]
        full_schedule = dynamic_resolution_h_w[nearest][self.pn]['scales']
        if self.inference_mode:
            mask_kind, auto_padding = "infinity_infer_mask_with_kv_cache", True
            scale_counts = range(1, len(full_schedule) + 1)
        else:
            mask_kind, auto_padding = 'var', False
            scale_counts = [min(self.always_training_scales, len(full_schedule))]

        for n_scales in scale_counts:
            print(f'====== apply flex attn hdivw: {h_div_w} scales: {n_scales} ======')
            schedule = [(min(t, max_t), h, w) for t, h, w in full_schedule[:n_scales]]
            _build(schedule, mask_kind, auto_padding=auto_padding)

        if self.video_frames > 1:
            # image-only variant of the last schedule built above
            schedule = [(1, h, w) for _t, h, w in schedule]
            _build(schedule, mask_kind)
    return compiled
| |
def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
    """
    Project hidden states to output logits in float32 with CUDA autocast disabled.

    :param h: hidden_state, shaped (B or batch_size, L or seq_len, C or hidden_dim)
    :param cond_BD: shaped (B or batch_size, D or cond_dim), or None when the head
        normalization takes no condition (e.g. an identity ``head_nm``)
    :return: logits, shaped (B or batch_size, L or seq_len, V or vocabulary_size)
    """
    with torch.amp.autocast('cuda', enabled=False):
        # honour the Optional annotation instead of crashing on None.float()
        cond = None if cond_BD is None else cond_BD.float()
        return self.head(self.head_nm(h.float(), cond))
|
|
def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0):
    """Add the level embedding of scale ``scale_ind`` in place to ``feature``.

    ``feature`` is unpacked as (batch, seq_len, channels). Its first t*h*w positions
    (from ``scale_schedule[scale_ind]``) receive ``self.lvl_embed(scale_ind)``. The
    remaining ``need_to_pad`` positions are left untouched.

    :raises AssertionError: if t*h*w + need_to_pad != seq_len
    :return: ``feature`` itself (modified in place)
    """
    batch_size, seq_len, _ = feature.shape
    t, h, w = scale_schedule[scale_ind]
    n_tokens = t * h * w
    assert n_tokens + need_to_pad == seq_len
    level_ids = torch.full((batch_size, n_tokens), scale_ind, dtype=torch.int, device=feature.device)
    feature[:, :n_tokens] += self.lvl_embed(level_ids)
    return feature
| |
def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0):
    """Add each scale's level embedding to its segment of ``x_BLC``.

    The sequence axis (dim 1) is walked scale by scale, taking t*h*w tokens per
    entry of ``scale_schedule``. Each segment goes through ``self.add_lvl_embeding``,
    which writes in place into the slice, so ``x_BLC`` itself is also modified. The
    trailing ``need_to_pad`` tokens are appended unchanged.

    :raises AssertionError: if the schedule plus padding does not cover ``x_BLC``
    :return: the concatenation of all segments along dim 1
    """
    pieces = []
    offset = 0
    for scale_ind, thw in enumerate(scale_schedule):
        n_tokens = np.array(thw).prod()
        segment = x_BLC[:, offset:offset + n_tokens]
        offset += n_tokens
        pieces.append(self.add_lvl_embeding(segment, scale_ind, scale_schedule))
    assert x_BLC.shape[1] == (offset + need_to_pad), f'{x_BLC.shape[1]} != {offset} + {need_to_pad}'
    pieces.append(x_BLC[:, offset:])
    return torch.cat(pieces, dim=1)
|
|
def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
            cfg_infer=False,
            x_BLC_w_prefix_lq=None,
            index=None,
            **kwargs,
            ) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Training forward pass; with ``cfg_infer=True`` dispatch to ``autoregressive_infer_cfg``.

    label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k); the training
        path unpacks it as (kv_compact, lens, cu_seqlens_k, max_seqlen_k). Entries of
        kv_compact may be overwritten in place by ``cfg_uncond`` (CFG dropout).
    x_BLC_wo_prefix: token features without the SOS position; passed through
        ``norm0_ve``/``word_embed`` (presumably (B, L-1, d_vae) -- TODO confirm)
    x_BLC_w_prefix_lq: LQ conditioning features; embedded like the tokens, projected by
        ``lq_linear`` and added to the sequence (presumably (B, L, d_vae) -- TODO confirm)
    index: per-sample scale index. Positions past the cumulative token count of scale
        ``index[j]`` are replaced by Gaussian noise. It also feeds
        ``dist.timestep_embedding``, whose result is currently not used.
    :return: logits BLV, V is vocab_size
    """
    if cfg_infer:
        # x_BLC_w_prefix_lq is a named parameter here, so it never lands in **kwargs;
        # pass it explicitly or autoregressive_infer_cfg receives None and crashes.
        return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, x_BLC_w_prefix_lq=x_BLC_w_prefix_lq, **kwargs)
    x_BLC_wo_prefix = x_BLC_wo_prefix.float()
    x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.float()
    B = x_BLC_wo_prefix.shape[0]

    with torch.amp.autocast('cuda', enabled=False):
        kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT

        # CFG dropout: replace a sample's text features with cfg_uncond at rate cond_drop_rate
        total = 0
        for le in lens:
            if random.random() < self.cond_drop_rate:
                kv_compact[total:total+le] = self.cfg_uncond[:le]
            total += le
        # zero-valued term keeping cfg_uncond in the autograd graph every step
        must_on_graph = self.cfg_uncond[0, 0] * 0
        kv_compact = self.text_norm(kv_compact).contiguous()
        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous()
        kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
        kv_compact[0, 0] += must_on_graph
        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k

        cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous()

        sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
        x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)
        x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq))

        patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]
        patch_nums_per_level_acc = [np.sum(patch_nums_per_level[:j+1]) for j in range(len(patch_nums_per_level))]

        # keep the tokens up to (and including) scale index[j]; noise everywhere after
        noise = torch.randn_like(x_BLC)
        mask = torch.zeros_like(x_BLC, dtype=torch.bool)
        patch_nums_per_batch = [patch_nums_per_level_acc[j] for j in index.cpu().tolist()]
        for j, p in enumerate(patch_nums_per_batch):
            mask[j, :p, :] = True
        x_BLC = torch.where(mask, x_BLC, noise)

        l_end = x_BLC.shape[1]
        need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end

        if self.customized_flash_attn:
            Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end]
            Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end]
            attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen)
        elif self.use_flex_attn:
            if need_to_pad:
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
            assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'
            attn_bias_or_two_vector = None
        else:
            d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1)
            dT = d.transpose(1, 2)
            attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end)
            attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous()
            if need_to_pad:
                attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf)
                attn_bias[0, 0, l_end:, 0] = 0
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
            attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device)

    # NOTE(review): the blocks below are called with attn_bias_or_two_vector=None and
    # attn_fn=None, so the mask / attn_fn computed here (and t_emb) are currently unused.
    if self.use_flex_attn:
        attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)]
    else:
        attn_fn = None

    t_emb = dist.timestep_embedding(index, self.C//4, repeat_only=False)
    t_emb = self.time_embed(t_emb)

    x_BLC_lq = self.lq_linear(x_BLC_lq)
    # x_BLC may have been padded above; pad the LQ features the same way so the sum
    # below does not fail to broadcast when need_to_pad > 0.
    if x_BLC_lq.shape[1] < x_BLC.shape[1]:
        x_BLC_lq = F.pad(x_BLC_lq, (0, 0, 0, x_BLC.shape[1] - x_BLC_lq.shape[1]))
    x_BLC = x_BLC + x_BLC_lq

    checkpointing_full_block = self.checkpointing == 'full-block' and self.training
    if self.num_block_chunks == 1:
        for i, b in enumerate(self.blocks):
            if self.add_lvl_embeding_only_first_block and i == 0:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if checkpointing_full_block:
                x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, None, None, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
            else:
                x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
    else:
        for i, chunk in enumerate(self.block_chunks):
            if self.add_lvl_embeding_only_first_block and i == 0:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)

    # drop the padding before projecting to logits
    return self.get_logits(x_BLC[:, :l_end], cond_BD)
|
|
| @torch.no_grad() |
| def autoregressive_infer_cfg( |
| self, |
| vae=None, |
| scale_schedule=None, |
| label_B_or_BLT=None, |
| B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None, |
| g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0, |
| returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False, |
| cfg_exp_k: float=0.0, cfg_insertion_layer=[-5], |
| vae_type=0, softmax_merge_topk=-1, ret_img=False, |
| trunk_scale=1000, |
| gt_leak=0, gt_ls_Bl=None, |
| inference_mode=False, |
| save_img_path=None, |
| sampling_per_bits=1, |
| x_BLC_w_prefix_lq=None, |
| ): |
| if g_seed is None: rng = None |
| else: self.rng.manual_seed(g_seed); rng = self.rng |
| assert len(cfg_list) >= len(scale_schedule) |
| assert len(tau_list) >= len(scale_schedule) |
|
|
| |
| |
| if self.apply_spatial_patchify: |
| vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] |
| else: |
| vae_scale_schedule = scale_schedule |
| |
| kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT |
| if any(np.array(cfg_list) != 1): |
| bs = 2*B |
| if not negative_label_B_or_BLT: |
| kv_compact_un = kv_compact.clone() |
| total = 0 |
| for le in lens: |
| kv_compact_un[total:total+le] = (self.cfg_uncond)[:le] |
| total += le |
| kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) |
| cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0) |
| else: |
| kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT |
| kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) |
| cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0) |
| max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un) |
| else: |
| bs = B |
|
|
| kv_compact = self.text_norm(kv_compact) |
| sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) |
| kv_compact = self.text_proj_for_ca(kv_compact) |
| ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k |
| last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1) |
|
|
| with torch.amp.autocast('cuda', enabled=False): |
| cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous() |
| accu_BChw, cur_L, ret = None, 0, [] |
| idx_Bl_list, idx_Bld_list = [], [] |
|
|
| if inference_mode: |
| for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True) |
| else: |
| assert self.num_block_chunks > 1 |
| for block_chunk_ in self.block_chunks: |
| for module in block_chunk_.module.module: |
| (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True) |
| |
| abs_cfg_insertion_layers = [] |
| add_cfg_on_logits, add_cfg_on_probs = False, False |
| leng = len(self.unregistered_blocks) |
| for item in cfg_insertion_layer: |
| if item == 0: |
| add_cfg_on_logits = True |
| elif item == 1: |
| add_cfg_on_probs = True |
| elif item < 0: |
| assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={self.num_block_chunks}' |
| abs_cfg_insertion_layers.append(leng+item) |
| else: |
| raise ValueError(f'cfg_insertion_layer: {item} is not valid') |
| |
| num_stages_minus_1 = len(scale_schedule)-1 |
| summed_codes = 0 |
| |
| |
| cfg = cfg_list[0] |
| |
|
|
| |
| |
| |
| |
| |
| patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] |
| patch_nums_per_level_acc = [np.sum(patch_nums_per_level[:j+1]) for j in range(len(patch_nums_per_level))] |
| x_BLC = torch.randn((bs,patch_nums_per_level_acc[-1],last_stage.shape[-1])).to(last_stage.device) |
| x_BLC[:,:1,:] = last_stage |
| l_end = x_BLC.shape[1] |
| |
| |
| x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.float() |
| x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq)) |
| x_BLC_lq = self.lq_linear(x_BLC_lq) |
| x_BLC = x_BLC + x_BLC_lq |
| |
| index = torch.zeros((bs,)).to(x_BLC.device) |
| t_emb = dist.timestep_embedding(index, self.C//4, repeat_only=False) |
| t_emb = self.time_embed(t_emb) |
| |
| |
|
|
| |
| layer_idx = 0 |
| |
| for block_idx, b in enumerate(self.block_chunks): |
| |
| if self.add_lvl_embeding_only_first_block and block_idx == 0: |
| x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad=0) |
| if not self.add_lvl_embeding_only_first_block: |
| x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad=0) |
| for m in b.module: |
| |
| x_BLC = m(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=0) |
| |
| if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers): |
| |
| x_BLC = cfg * x_BLC[:B] + (1-cfg) * x_BLC[B:] |
| x_BLC = torch.cat((x_BLC, x_BLC), 0) |
| layer_idx += 1 |
| |
| |
| if (cfg != 1) and add_cfg_on_logits: |
| logits_BlV = self.get_logits(x_BLC[:, :l_end], cond_BD).mul(1/tau_list[0]) |
| logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:] |
| else: |
| logits_BlV = self.get_logits(x_BLC[:B, :l_end], cond_BD[:B]).mul(1/tau_list[0]) |
| |
| |
| img = self.logits_to_img_discrete(logits_BlV_all=logits_BlV, |
| vae=vae, |
| scale_schedule=scale_schedule, |
| top_k=top_k, |
| top_p=top_p, |
| g_seed=g_seed) |
| |
|
|
| if inference_mode: |
| for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) |
| else: |
| assert self.num_block_chunks > 1 |
| for block_chunk_ in self.block_chunks: |
| for module in block_chunk_.module.module: |
| (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) |
|
|
| img = (img + 1) / 2 |
| img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) |
| |
| return None,None,img |
| |
@torch.no_grad()
def autoregressive_infer_cfg_multi_step(
    self,
    vae=None,
    scale_schedule=None,
    label_B_or_BLT=None,
    B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
    g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
    returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
    cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
    vae_type=0, softmax_merge_topk=-1, ret_img=False,
    trunk_scale=1000,
    gt_leak=0, gt_ls_Bl=None,
    inference_mode=False,
    save_img_path=None,
    sampling_per_bits=1,
    x_BLC_w_prefix_lq=None,
    index_list=None
):
    """Run inference as several full-sequence passes through the transformer.

    For each ``index`` in ``index_list`` the tokens from position
    ``sum(t*h*w of the first index+1 scales)`` onward are overwritten with fresh
    Gaussian noise, the embedded ``x_BLC_w_prefix_lq`` sequence is added, and the
    whole sequence goes through every block chunk (``attn_bias_or_two_vector=None``,
    ``scale_ind=0``). Between passes, bit labels are sampled from the logits,
    converted to VAE codes and re-embedded as the next pass's input sequence. The
    logits of the final pass are decoded by ``logits_to_img_discrete``.

    Only ``cfg_list[0]`` and ``tau_list[0]`` are used. Many keyword arguments are
    accepted for signature compatibility with the single-pass sampler and ignored.

    Args:
        vae: VAE used to turn sampled bits into codes and to decode the image.
        scale_schedule: list of (t, h, w) token grids, one per scale.
        label_B_or_BLT: (kv_compact, lens, cu_seqlens_k, max_seqlen_k) text features.
        B: number of images.
        negative_label_B_or_BLT: optional negative-prompt features in the same
            format; when absent, ``self.cfg_uncond`` rows fill the unconditional half.
        g_seed: seed for ``self.rng``; None leaves sampling unseeded.
        cfg_list, tau_list: guidance scale / temperature; each must be at least as
            long as ``scale_schedule``. Any value != 1 in cfg_list doubles the batch.
        top_k, top_p: sampling filters (0 falls back to ``self.top_k``/``self.top_p``).
        cfg_insertion_layer: 0 applies CFG on logits; a negative value applies it
            after that block counted from the end; 1 is accepted but unused here.
        inference_mode: toggles KV caching on ``self.unregistered_blocks`` (True) or
            on the modules inside ``self.block_chunks`` (False).
        x_BLC_w_prefix_lq: required; embedded via ``norm0_ve``/``word_embed``/
            ``lq_linear`` and added to the sequence before every pass.
        index_list: scale index at which re-noising starts, one entry per pass; must
            start with 0. Defaults to ``[0]`` (a single pass starting from noise).

    Returns:
        ``(None, None, img)``; ``img`` is uint8, permuted with (0, 2, 3, 1) and with
        the last axis reversed (assumes vae.decode returns (B, C, H, W) — TODO confirm).

    Raises:
        ValueError: if ``index_list`` is empty or does not start with 0, or if
            ``x_BLC_w_prefix_lq`` is None.
        NotImplementedError: if ``self.use_bit_label`` is false.
    """
    if index_list is None:
        index_list = [0]  # a single pass starting from pure noise
    if len(index_list) == 0 or index_list[0] != 0:
        raise ValueError(f'index_list must be non-empty and start with 0, got {index_list}')
    if x_BLC_w_prefix_lq is None:
        raise ValueError('x_BLC_w_prefix_lq is required')
    if not self.use_bit_label:
        # Sampling and decoding below always use label_type='bit_label'.
        raise NotImplementedError('autoregressive_infer_cfg_multi_step only supports use_bit_label=1')

    if g_seed is None:
        rng = None
    else:
        self.rng.manual_seed(g_seed)
        rng = self.rng
    assert len(cfg_list) >= len(scale_schedule)
    assert len(tau_list) >= len(scale_schedule)

    if self.apply_spatial_patchify:
        vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
    else:
        vae_scale_schedule = scale_schedule

    # ---- text conditioning; with CFG a second (unconditional) half is appended ----
    kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
    if any(np.array(cfg_list) != 1):
        bs = 2 * B
        if not negative_label_B_or_BLT:
            kv_compact_un = kv_compact.clone()
            total = 0
            for le in lens:
                kv_compact_un[total:total+le] = self.cfg_uncond[:le]
                total += le
            kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
            cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:] + cu_seqlens_k[-1]), dim=0)
        else:
            kv_compact_un, _lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
            kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
            cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:] + cu_seqlens_k[-1]), dim=0)
            max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
    else:
        bs = B

    kv_compact = self.text_norm(kv_compact)
    sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k))
    kv_compact = self.text_proj_for_ca(kv_compact)
    ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
    last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)

    with torch.amp.autocast('cuda', enabled=False):
        cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()

    # ---- where classifier-free guidance is mixed in ----
    abs_cfg_insertion_layers = []
    add_cfg_on_logits = False
    leng = len(self.unregistered_blocks)
    for item in cfg_insertion_layer:
        if item == 0:
            add_cfg_on_logits = True
        elif item == 1:
            pass  # CFG on probabilities is accepted but not applied by this sampler
        elif item < 0:
            assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={leng}'
            abs_cfg_insertion_layers.append(leng+item)
        else:
            raise ValueError(f'cfg_insertion_layer: {item} is not valid')

    cfg = cfg_list[0]
    tau = tau_list[0]

    x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq.float()))
    x_BLC_lq = self.lq_linear(x_BLC_lq)

    patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]
    patch_nums_per_level_acc = np.cumsum(patch_nums_per_level).tolist()

    x_BLC = torch.zeros((bs, patch_nums_per_level_acc[-1], last_stage.shape[-1])).to(last_stage.device)
    x_BLC[:, :1, :] = last_stage
    l_end = x_BLC.shape[1]

    def _set_kv_caching(enable):
        # Toggle KV caching on every self-attention module that will be run.
        if inference_mode:
            for blk in self.unregistered_blocks:
                (blk.sa if isinstance(blk, CrossAttnBlock) else blk.attn).kv_caching(enable)
        else:
            assert self.num_block_chunks > 1
            for block_chunk_ in self.block_chunks:
                for module in block_chunk_.module.module:
                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(enable)

    # NOTE(review): caching stays enabled across all passes. If kv_caching(True)
    # makes attention append keys/values between calls, later passes would also
    # attend to earlier passes' keys — confirm against the attention implementation.
    _set_kv_caching(True)
    try:
        logits_final = None
        for step, index in enumerate(index_list):
            # re-noise every token after the first (index + 1) scales
            start = patch_nums_per_level_acc[index]
            noise = torch.randn_like(x_BLC)
            x_BLC[:, start:, :] = noise[:, start:, :]
            x_BLC = x_BLC + x_BLC_lq

            # NOTE(review): t_emb is computed but never passed to the blocks below.
            index_tensor = torch.full((bs,), index).to(x_BLC.device)
            t_emb = dist.timestep_embedding(index_tensor, self.C//4, repeat_only=False)
            t_emb = self.time_embed(t_emb)

            layer_idx = 0
            for block_idx, b in enumerate(self.block_chunks):
                if self.add_lvl_embeding_only_first_block and block_idx == 0:
                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad=0)
                if not self.add_lvl_embeding_only_first_block:
                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad=0)
                for m in b.module:
                    x_BLC = m(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=0)
                    if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers):
                        x_BLC = cfg * x_BLC[:B] + (1-cfg) * x_BLC[B:]
                        x_BLC = torch.cat((x_BLC, x_BLC), 0)
                    layer_idx += 1

            if (cfg != 1) and add_cfg_on_logits:
                logits_BlV = self.get_logits(x_BLC[:, :l_end], cond_BD).mul(1/tau)
                logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:]
            else:
                logits_BlV = self.get_logits(x_BLC[:B, :l_end], cond_BD[:B]).mul(1/tau)

            if step == len(index_list) - 1:
                logits_final = logits_BlV
                break

            # ---- sample bits and rebuild the next pass's input sequence ----
            tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
            idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(
                logits_BlV.reshape(tmp_bs, -1, 2), rng=rng,
                top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1,
            )[:, :, 0]
            idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)

            cum_var_input = 0
            x_BLC_wo_prefix = []
            for si, bit_indices in enumerate(torch.split(idx_Bld, patch_nums_per_level, dim=1)):
                pt, ph, pw = scale_schedule[si]
                if self.apply_spatial_patchify:
                    # each token packs a 2x2 patch of bits into its last dim: unfold it
                    bit_indices = bit_indices.reshape(B * pt, ph, pw, -1).permute(0, 3, 1, 2)
                    bit_indices = F.pixel_shuffle(bit_indices, 2).permute(0, 2, 3, 1)
                    bit_indices = bit_indices.reshape(B, pt, 2 * ph, 2 * pw, -1)
                else:
                    bit_indices = bit_indices.reshape(B, pt, ph, pw, -1)

                quantized = vae.quantizer.lfq.indices_to_codes(bit_indices, label_type='bit_label')
                cum_var_input += F.interpolate(quantized, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)

                if si < len(vae_scale_schedule) - 1:
                    this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_up)
                    if self.apply_spatial_patchify:
                        this_scale_input = F.pixel_unshuffle(this_scale_input.squeeze(-3), 2)
                    x_BLC_wo_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0, 2, 1))

            prefix_emb = self.word_embed(self.norm0_ve(torch.cat(x_BLC_wo_prefix, 1)))
            if bs != B:
                # CFG: both halves of the batch continue from the same guided sample
                prefix_emb = prefix_emb.repeat(bs // B, 1, 1)
            x_BLC = torch.cat((last_stage, prefix_emb), dim=1)
            assert x_BLC.shape[1] == l_end

        img = self.logits_to_img_discrete(logits_BlV_all=logits_final,
                                          vae=vae,
                                          scale_schedule=scale_schedule,
                                          top_k=top_k,
                                          top_p=top_p,
                                          g_seed=g_seed)
    finally:
        _set_kv_caching(False)

    img = (img + 1) / 2
    # clamp so out-of-range decoder outputs saturate instead of wrapping in the uint8 cast
    img = img.permute(0, 2, 3, 1).mul_(255).clamp_(0, 255).to(torch.uint8).flip(dims=(3,))
    return None, None, img
| |
def logits_to_img(self, logits_BlV_all, vae, scale_schedule, top_k=900, top_p=0.97, g_seed=1):
    """Decode per-scale bit logits into an image, choosing bits with ``GumbelArgmax``.

    :param logits_BlV_all: logits of all scales concatenated along dim 1; each
        token's last dim is viewed as (num_bits, 2).
    :param vae: VAE providing ``quantizer.lfq.indices_to_codes``,
        ``quantizer.z_interplote_up`` and ``decode``.
    :param scale_schedule: list of (t, h, w) token grids; t must be 1 for every scale.
    :param top_k: unused on the bit-label path (kept for signature compatibility).
    :param top_p: unused on the bit-label path (kept for signature compatibility).
    :param g_seed: when not None, ``self.rng`` is re-seeded with it.
    :return: ``vae.decode`` applied to the sum of all scales' codes upsampled to the
        final scale.
    :raises NotImplementedError: if ``self.use_bit_label`` is false (decoding below
        always uses ``label_type='bit_label'``).
    """
    if not self.use_bit_label:
        raise NotImplementedError('logits_to_img only supports use_bit_label=1')

    if g_seed is not None:
        self.rng.manual_seed(g_seed)

    if self.apply_spatial_patchify:
        vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
    else:
        vae_scale_schedule = scale_schedule

    patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]
    logits_per_scale = torch.split(logits_BlV_all, patch_nums_per_level, dim=1)
    B = logits_BlV_all.shape[0]
    last_si = len(scale_schedule) - 1

    summed_codes = 0
    for si, logits_BlV in enumerate(logits_per_scale):
        pn = scale_schedule[si]
        tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
        choice = GumbelArgmax(logits_BlV.reshape(tmp_bs, -1, 2), 0.5)
        # Zero the first of the two choices and sum over the last dim, so the
        # weight of choice 1 becomes the bit (assumes a (B, L*bits, 2) output).
        mask = torch.zeros_like(choice)
        mask[:, :, 1:] = 1
        idx_Bld = (choice * mask).sum(dim=-1).reshape(tmp_bs, tmp_seq_len, -1)

        assert pn[0] == 1
        idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1)
        if self.apply_spatial_patchify:
            # unfold the 2x2 patch packed into each token's channels
            idx_Bld = F.pixel_shuffle(idx_Bld.permute(0, 3, 1, 2), 2).permute(0, 2, 3, 1)
        idx_Bld = idx_Bld.unsqueeze(1)

        codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label')
        if si != last_si:
            summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
        else:
            summed_codes += codes

    return vae.decode(summed_codes.squeeze(-3))
|
|
def logits_to_img_discrete(self, logits_BlV_all, vae, scale_schedule, top_k=900, top_p=0.97, g_seed=1):
    """Decode per-scale bit logits into an image by sampling each bit.

    :param logits_BlV_all: logits of all scales concatenated along dim 1; each
        token's last dim is viewed as (num_bits, 2) and sampled per bit with
        ``sample_with_top_k_top_p_also_modifying_logits_`` (which, per its name,
        may modify the logits in place).
    :param vae: VAE providing ``quantizer.lfq.indices_to_codes``,
        ``quantizer.z_interplote_up`` and ``decode``.
    :param scale_schedule: list of (t, h, w) token grids; t must be 1 for every scale.
    :param top_k: top-k filter; a falsy value falls back to ``self.top_k``.
    :param top_p: top-p filter; a falsy value falls back to ``self.top_p``.
    :param g_seed: when not None, ``self.rng`` is re-seeded and used for sampling.
    :return: ``vae.decode`` applied to the sum of all scales' codes upsampled to the
        final scale.
    :raises NotImplementedError: if ``self.use_bit_label`` is false (decoding below
        always uses ``label_type='bit_label'``).
    """
    if not self.use_bit_label:
        raise NotImplementedError('logits_to_img_discrete only supports use_bit_label=1')

    if g_seed is None:
        rng = None
    else:
        self.rng.manual_seed(g_seed)
        rng = self.rng

    if self.apply_spatial_patchify:
        vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
    else:
        vae_scale_schedule = scale_schedule

    patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]
    logits_per_scale = torch.split(logits_BlV_all, patch_nums_per_level, dim=1)
    B = logits_BlV_all.shape[0]
    last_si = len(scale_schedule) - 1

    summed_codes = 0
    for si, logits_BlV in enumerate(logits_per_scale):
        pn = scale_schedule[si]
        tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
        idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(
            logits_BlV.reshape(tmp_bs, -1, 2), rng=rng,
            top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1,
        )[:, :, 0]
        idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)

        assert pn[0] == 1
        idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1)
        if self.apply_spatial_patchify:
            # unfold the 2x2 patch packed into each token's channels
            idx_Bld = F.pixel_shuffle(idx_Bld.permute(0, 3, 1, 2), 2).permute(0, 2, 3, 1)
        idx_Bld = idx_Bld.unsqueeze(1)

        codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label')
        if si != last_si:
            summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
        else:
            summed_codes += codes

    return vae.decode(summed_codes.squeeze(-3))
| |
| |
@for_visualize
def vis_key_params(self, ep):
    """Visualization hook for key parameters at epoch ``ep``.

    Currently does nothing and returns None.
    """
    del ep  # not used by this implementation
    return None
| |
def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False):
    """Load a checkpoint after reconciling ``cfg_uncond`` and overriding some buffers.

    ``state_dict`` is modified in place before being forwarded to the parent
    implementation:

    * every entry whose key contains ``'cfg_uncond'`` is fitted to the length of
      ``self.cfg_uncond``: a shorter checkpoint tensor is cast to the model's
      device/dtype and padded with the model's own trailing rows; a longer one is
      truncated;
    * the entries ``lvl_1L``, ``attn_bias_for_masking``, ``Infinity_visible_kvlen``
      and ``Infinity_invisible_qlen`` are dropped and, where the module has an
      attribute of that name, replaced with the module's current value.
    """
    uncond_keys = [name for name in state_dict if 'cfg_uncond' in name]
    for name in uncond_keys:
        ckpt_value = state_dict[name]
        model_value = self.cfg_uncond.data
        ckpt_len = ckpt_value.shape[0]
        model_len = model_value.shape[0]
        if ckpt_len > model_len:
            state_dict[name] = ckpt_value[:model_len]
        else:
            casted = ckpt_value.to(device=model_value.device, dtype=model_value.dtype)
            state_dict[name] = torch.cat((casted, model_value[ckpt_len:]))

    missing = object()
    for buf_name in ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen'):
        state_dict.pop(buf_name, None)
        current = getattr(self, buf_name, missing)
        if current is not missing:
            state_dict[buf_name] = current

    return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
| |
def special_init(
    self,
    aln_init: float,
    aln_gamma_init: float,
    scale_head: float,
    scale_proj: int,
):
    """Rescale selected weights in place after default initialization.

    :param aln_init: multiplier for the head AdaLN projection and for the non-gamma
        part of every block's ada_lin / ada_gss.
    :param aln_gamma_init: multiplier for the first ``2*C`` rows of each block's
        ada_lin weight (or the first two slots of ada_gss).
    :param scale_head: if >= 0, multiplier for the output head's last linear weight
        (its bias is zeroed).
    :param scale_proj: when 1, attention/cross-attention output projections and
        ``ffn.fc2`` are scaled by ``1/sqrt(2*depth)``; other values leave them as is.
    """
    head_nm = self.head_nm
    if isinstance(head_nm, AdaLNBeforeHead):
        head_ada = head_nm.ada_lin[-1]
        head_ada.weight.data.mul_(aln_init)
        if getattr(head_ada, 'bias', None) is not None:
            head_ada.bias.data.zero_()

    if scale_head >= 0:
        out_linear = None
        if isinstance(self.head, nn.Linear):
            out_linear = self.head
        elif isinstance(self.head, nn.Sequential):
            out_linear = self.head[-1]
        if out_linear is not None:
            out_linear.weight.data.mul_(scale_head)
            out_linear.bias.data.zero_()

    num_blocks = len(self.unregistered_blocks)
    gamma_rows = 2 * self.C
    for blk in self.unregistered_blocks:
        if scale_proj == 1:
            proj_scale = 1 / math.sqrt(2 * num_blocks)
            if self.t2i:
                blk.sa.proj.weight.data.mul_(proj_scale)
                blk.ca.proj.weight.data.mul_(proj_scale)
            else:
                blk.attn.proj.weight.data.mul_(proj_scale)
            blk.ffn.fc2.weight.data.mul_(proj_scale)

        if hasattr(blk, 'ada_lin'):
            ada = blk.ada_lin[-1]
            ada.weight.data[:gamma_rows].mul_(aln_gamma_init)
            ada.weight.data[gamma_rows:].mul_(aln_init)
            if getattr(ada, 'bias', None) is not None:
                ada.bias.data.zero_()
        elif hasattr(blk, 'ada_gss'):
            gss = blk.ada_gss.data
            gss[:, :, :2, :].mul_(aln_gamma_init)
            gss[:, :, 2:, :].mul_(aln_init)
| |
def extra_repr(self):
    """Extra text shown in the module's repr: the configured drop-path rate."""
    return 'drop_path_rate={}'.format(self.drop_path_rate)
| |
def get_layer_id_and_scale_exp(self, para_name: str):
    """Map a parameter name to (layer id, scale exponent); not implemented for this model.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
|
|
| class FInfinity(nn.Module): |
def __init__(
    self, vae_local,
    text_channels=0, text_maxlen=0,
    selecting_idx=None,
    embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4.,
    drop_rate=0., drop_path_rate=0.,
    norm_eps=1e-6, rms_norm=False,
    shared_aln=False, head_aln=True,
    cond_drop_rate=0.1,
    rand_uncond=False,
    cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False,
    raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
    head_depth=1,
    top_p=0.0, top_k=0.0,
    customized_flash_attn=False, fused_mlp=False, fused_norm=False,
    block_chunks=1,
    checkpointing=None,
    pad_to_multiplier=0,
    use_flex_attn=False,
    batch_size=2,
    add_lvl_embeding_only_first_block=1,
    use_bit_label=1,
    rope2d_each_sa_layer=0,
    rope2d_normalized_by_hw=0,
    pn=None,
    train_h_div_w_list=None,
    video_frames=1,
    always_training_scales=20,
    apply_spatial_patchify = 0,
    inference_mode=False,
):
    """Build the multi-scale transformer (text-conditioned when ``text_channels != 0``).

    Key arguments:
        vae_local: VAE whose ``embed_dim``, ``vocab_size`` and ``quantizer.lfq.mask``
            size the token embedding and vocabulary.
        text_channels, text_maxlen: text feature width / maximum text length;
            ``text_channels == 0`` selects class conditioning instead.
        embed_dim, depth, num_heads, mlp_ratio: transformer width/depth/heads/MLP ratio.
        block_chunks: number of ``MultipleLayers`` chunks; a falsy value means 1.
            ``depth`` must be divisible by the resulting chunk count.
        apply_spatial_patchify: when truthy, each token carries 4x the VAE channels.
        use_bit_label: when truthy, the vocabulary is ``2 * codebook_dim`` bit logits.
        rope2d_each_sa_layer: must be 1; 0 raises ValueError.
        top_k, top_p: default sampling filters (top_k in (0, 1) is a fraction of V).

    Side effects: performs a distributed barrier/all-reduce to check that
    ``flash_fused_op_installed`` agrees across ranks, and prints configuration.
    """
    self.C = embed_dim
    self.inference_mode = inference_mode
    self.apply_spatial_patchify = apply_spatial_patchify
    if self.apply_spatial_patchify:
        self.d_vae = vae_local.embed_dim * 4
    else:
        self.d_vae = vae_local.embed_dim
    self.use_bit_label = use_bit_label
    self.codebook_dim = self.d_vae
    self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size
    self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None
    self.Ct5 = text_channels
    self.depth = depth
    self.num_heads = num_heads
    self.batch_size = batch_size
    self.mlp_ratio = mlp_ratio
    self.cond_drop_rate = cond_drop_rate
    self.norm_eps = norm_eps
    self.prog_si = -1
    self.pn = pn
    self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates
    self.video_frames = video_frames
    self.always_training_scales = always_training_scales

    assert add_lvl_embeding_only_first_block in [0,1]
    self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block
    assert rope2d_each_sa_layer in [0,1]
    self.rope2d_each_sa_layer = rope2d_each_sa_layer
    self.rope2d_normalized_by_hw = rope2d_normalized_by_hw
    print(
        f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, '
        f'self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}'
    )
    head_up_method = ''
    word_patch_size = 1 if head_up_method in {'', 'no'} else 2
    if word_patch_size > 1:
        assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}'

    self.checkpointing = checkpointing
    self.pad_to_multiplier = max(1, pad_to_multiplier)

    customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames)
    self.customized_flash_attn = customized_flash_attn and customized_kernel_installed
    if customized_flash_attn and not customized_kernel_installed:
        import inspect
        import warnings
        file_path = inspect.getsourcefile(flash_attn_func)
        line_number = inspect.getsourcelines(flash_attn_func)[1]
        info = (
            f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<\n'
            f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n'
            f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n'
        )
        warnings.warn(info, ImportWarning)
        print(info, flush=True)

    self.raw_scale_schedule = raw_scale_schedule
    self.first_l = 1

    self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k))
    if self.top_p < 1e-5:
        self.top_p = 0
    if self.top_k >= self.V or self.top_k <= 0:
        self.top_k = 0

    # every rank must agree on whether the fused ops are installed
    t = torch.zeros(dist.get_world_size(), device=dist.get_device())
    t[dist.get_rank()] = float(flash_fused_op_installed)
    dist.barrier()
    dist.allreduce(t)
    assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}'

    super().__init__()
    self.rng = torch.Generator(device=dist.get_device())
    self.maybe_record_function = nullcontext
    self.text_maxlen = text_maxlen
    self.t2i = text_channels != 0

    init_std = math.sqrt(1 / self.C / 3)
    self.norm0_cond = nn.Identity()

    self.time_embed = nn.Sequential(
        nn.Linear(self.C//4, self.C),
        nn.SiLU(),
        nn.Linear(self.C, self.C),
    )

    if self.t2i:
        self.selecting_idx = None
        self.num_classes = 0
        self.D = self.C

        # deterministic (seed 0) unconditional text embedding for classifier-free guidance
        cfg_uncond = torch.empty(self.text_maxlen, self.Ct5)
        rng = torch.Generator(device='cpu')
        rng.manual_seed(0)
        torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng)
        cfg_uncond /= self.Ct5 ** 0.5
        if rand_uncond:
            self.register_buffer('cfg_uncond', cfg_uncond)
        else:
            self.cfg_uncond = nn.Parameter(cfg_uncond)

        self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps)
        self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D)
        self.text_proj_for_ca = nn.Sequential(
            nn.Linear(self.Ct5, self.D),
            nn.GELU(approximate='tanh'),
            nn.Linear(self.D, self.D),
        )
    else:
        if selecting_idx is None:
            num_classes = 1000
            print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======')
            selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device())
        self.selecting_idx = selecting_idx
        self.num_classes = selecting_idx.shape[-1]
        self.D = self.C
        self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
        nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)

    self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
    nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
    if self.rope2d_each_sa_layer:
        rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw)
        self.rope2d_freqs_grid = rope2d_freqs_grid
    else:
        raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented')
    self.lvl_embed = nn.Embedding(15, self.C)
    nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

    norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps)
    self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity()
    self.word_embed = nn.Linear(self.d_vae, self.C)

    self.norm0_ve_lq = norm_layer(self.d_vae) if nm0 else nn.Identity()
    self.word_embed_lq = nn.Linear(self.d_vae, self.C)

    self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()

    if fused_norm:
        fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm
        if fused_norm_func is not None:
            # These tensors are unused, but the randn draws advance the global RNG,
            # so removing them would change the initialization of everything below.
            B = 2
            x = torch.randn(B, 1, self.C).requires_grad_(True)
            scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
            shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
            del B, x, scale, shift
    else:
        fused_norm_func = None

    self.use_flex_attn = use_flex_attn
    self.attn_fn_compile_dict = {}
    if self.use_flex_attn:
        self.attn_fn_compile_dict = self.compile_flex_attn()

    self.drop_path_rate = drop_path_rate
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    self.unregistered_blocks = []
    for block_idx in range(depth):
        block = (CrossAttnBlock if self.t2i else SelfAttnBlock)(
            embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer,
            num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn,
            swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func,
            checkpointing_sa_only=self.checkpointing == 'self-attn',
            use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
        )
        self.unregistered_blocks.append(block)

    V = self.V
    if head_aln:
        self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func)
        self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
    else:
        self.head_nm = MultiInpIdentity()
        self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))

    # Use the normalized chunk count everywhere: a falsy block_chunks means one chunk
    # (previously `depth // block_chunks` crashed for 0/None).
    self.num_block_chunks = block_chunks or 1
    self.num_blocks_in_a_chunk = depth // self.num_block_chunks
    print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}")
    assert self.num_blocks_in_a_chunk * self.num_block_chunks == depth
    if self.num_block_chunks == 1:
        self.blocks = nn.ModuleList(self.unregistered_blocks)
    else:
        self.block_chunks = nn.ModuleList()
        for i in range(self.num_block_chunks):
            self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk))
    print(
        f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n'
        f' [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n'
        f' [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
        end='\n\n', flush=True
    )
    self.car_control_convs = ControlConditionEmbedding(conditioning_embedding_channels=self.C)
| |
| |
|
|
def compile_flex_attn(self):
    """Pre-build the FlexAttn callables, keyed by the tuple of (t, h, w) scales.

    For each aspect ratio in ``self.train_h_div_w_list``, the nearest entry of
    ``h_div_w_templates`` selects a full scale schedule from
    ``dynamic_resolution_h_w[...][self.pn]['scales']``:

    * inference mode: one FlexAttn per scale prefix (1..all scales), mask type
      ``"infinity_infer_mask_with_kv_cache"``, ``auto_padding=True``;
    * training: one FlexAttn for the first ``min(always_training_scales, n)`` scales,
      mask type ``'var'``, ``auto_padding=False``.

    Temporal extents are clamped to ``video_frames // 4 + 1``, and sequence lengths
    are rounded up to a multiple of ``pad_to_multiplier``. When ``video_frames > 1``,
    a single-frame (t=1) variant of the last schedule is also added; that one is
    constructed without an ``auto_padding`` argument.

    :return: dict mapping tuple(scale_schedule) -> FlexAttn
    """
    compiled = {}

    def _register(scales, mask_type, **extra):
        # pad the packed sequence length up to a multiple of pad_to_multiplier
        seq_len = sum(pt * ph * pw for pt, ph, pw in scales)
        remainder = seq_len % self.pad_to_multiplier
        padded_len = seq_len + (self.pad_to_multiplier - remainder) if remainder != 0 else seq_len
        compiled[scales] = FlexAttn(block_scales=scales,
                                    mask_type=mask_type,
                                    B=self.batch_size,
                                    H=self.num_heads,
                                    L=padded_len,
                                    **extra)

    for h_div_w in self.train_h_div_w_list:
        nearest = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))]
        full_scale_schedule = dynamic_resolution_h_w[nearest][self.pn]['scales']
        if self.inference_mode:
            mask_type, auto_padding = "infinity_infer_mask_with_kv_cache", True
            scale_counts = range(1, len(full_scale_schedule) + 1)
        else:
            mask_type, auto_padding = 'var', False
            scale_counts = [min(self.always_training_scales, len(full_scale_schedule))]

        for scales_num in scale_counts:
            print(f'====== apply flex attn hdivw: {h_div_w} scales: {scales_num} ======')
            max_t = self.video_frames // 4 + 1
            scale_schedule = [(min(t, max_t), h, w) for (t, h, w) in full_scale_schedule[:scales_num]]
            _register(tuple(scale_schedule), mask_type, auto_padding=auto_padding)

        if self.video_frames > 1:
            # image (single-frame) variant of the last video schedule
            scale_schedule = [(1, h, w) for (_t, h, w) in scale_schedule]
            _register(tuple(scale_schedule), mask_type)

    return compiled
| |
def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
    """Project hidden states to logits with autocast disabled (fp32 math).

    :param h: hidden states, shaped (B, L, C); cast to float32.
    :param cond_BD: conditioning passed to ``self.head_nm`` together with ``h``; cast
        to float32 when given. None is forwarded as-is, which suits head norms that
        ignore the condition (e.g. ``MultiInpIdentity``).
    :return: ``self.head(self.head_nm(h, cond_BD))``; with a Linear(C, V) head this is
        shaped (B, L, V).
    """
    with torch.amp.autocast('cuda', enabled=False):
        cond = None if cond_BD is None else cond_BD.float()
        return self.head(self.head_nm(h.float(), cond))
|
|
def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0):
    """Add the level embedding of scale ``scale_ind`` to the leading tokens of ``feature``.

    ``feature`` (batch, seq_len, C) is modified in place and returned. Its first
    ``t*h*w`` tokens (from ``scale_schedule[scale_ind]``) receive
    ``self.lvl_embed(scale_ind)``; the trailing ``need_to_pad`` tokens are untouched.
    Asserts that ``t*h*w + need_to_pad == seq_len``.
    """
    batch, total_len, _ = feature.shape
    pt, ph, pw = scale_schedule[scale_ind]
    n_tokens = pt * ph * pw
    assert n_tokens + need_to_pad == total_len
    level_ids = torch.full((batch, n_tokens), scale_ind, dtype=torch.int, device=feature.device)
    feature[:, :n_tokens] += self.lvl_embed(level_ids)
    return feature
| |
def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0):
    """Add per-scale level embeddings to a packed multi-scale token sequence.

    ``x_BLC`` holds the tokens of every scale back to back along dim 1. Each scale's
    slice of ``t*h*w`` tokens is passed to ``self.add_lvl_embeding`` with its scale
    index; any tokens after the last scale are appended unchanged. The result is a
    new tensor built with ``torch.cat``.

    ``need_to_pad`` is accepted for signature compatibility and is not used.
    """
    pieces = []
    offset = 0
    for level, grid in enumerate(scale_schedule):
        span = np.array(grid).prod()
        chunk = x_BLC[:, offset:offset + span]
        pieces.append(self.add_lvl_embeding(chunk, level, scale_schedule))
        offset += span
    pieces.append(x_BLC[:, offset:])  # padding / tail tokens
    return torch.cat(pieces, dim=1)
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
            cfg_infer=False, lq_images=None, index=None,
            **kwargs,
            ) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Training forward pass over the whole multi-scale sequence, conditioned on LQ images.

    label_B_or_BLT: (kv_compact, lens, cu_seqlens_k, max_seqlen_k) packed text features.
        ``kv_compact`` is modified in place by condition dropout.
    x_BLC_wo_prefix: inputs for every token after the start token, (B, L - 1, *)
    scale_schedule: list of (t, h, w) patch grids, one per scale
    cfg_infer: if True, delegate to ``self.autoregressive_infer_cfg``.
        ``lq_images`` and ``**kwargs`` are forwarded to it.
    lq_images: low-quality conditioning images fed to ``self.car_control_convs`` (required)
    index: one step index per sample (required). Sample ``j`` keeps the tokens of scales
        ``0..index[j]``; every later token is replaced by Gaussian noise. ``index`` also
        drives the step embedding.
    :return: logits BLV, V is vocab_size (LQ and pad tokens are stripped)
    """
    if cfg_infer:
        # lq_images is a named parameter here, so it must be passed on explicitly.
        return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, lq_images=lq_images, **kwargs)
    if lq_images is None or index is None:
        raise ValueError('forward() requires both `lq_images` and `index` when cfg_infer=False')

    x_BLC_wo_prefix = x_BLC_wo_prefix.float()
    lq_images = lq_images.float()
    B = x_BLC_wo_prefix.shape[0]

    with torch.amp.autocast('cuda', enabled=False):
        kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT

        # Condition dropout: with probability cond_drop_rate, replace a sample's text tokens by
        # the learned unconditional embedding.
        total = 0
        for le in lens:
            if random.random() < self.cond_drop_rate:
                kv_compact[total:total + le] = self.cfg_uncond[:le]
            total += le
        # Zero-valued term added below so cfg_uncond is always part of the autograd graph,
        # even when no sample was dropped.
        must_on_graph = self.cfg_uncond[0, 0] * 0
        kv_compact = self.text_norm(kv_compact).contiguous()
        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous()
        kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
        kv_compact[0, 0] += must_on_graph
        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k

        cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous()

        sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
        x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)
        x_BLC_lq = self.car_control_convs(lq_images)
        x_BLC_lq = x_BLC_lq.view(B, self.C, -1).transpose(1, 2).contiguous()

        # Token count at the end of each scale (prefix sums of t*h*w).
        patch_nums_per_level = [pn[0] * pn[1] * pn[2] for pn in scale_schedule]
        patch_nums_per_level_acc = np.cumsum(patch_nums_per_level).tolist()

        # Keep the first acc[index[j]] tokens of sample j; noise everywhere else.
        noise = torch.randn_like(x_BLC)
        keep = torch.zeros_like(x_BLC, dtype=torch.bool)
        for row, step in enumerate(index.cpu().tolist()):
            keep[row, :patch_nums_per_level_acc[step]] = True
        x_BLC = torch.where(keep, x_BLC, noise)

        l_end = x_BLC.shape[1]
        multiplier = max(1, self.pad_to_multiplier)  # 0 would divide by zero
        need_to_pad = (l_end + multiplier - 1) // multiplier * multiplier - l_end

        # Pad the target tokens to a multiple of pad_to_multiplier (not done on the
        # customized-flash-attn path). No attention bias or flex kernel is prepared: every block
        # below is called with attn_bias_or_two_vector=None and attn_fn=None.
        if not self.customized_flash_attn:
            if need_to_pad:
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
            if self.use_flex_attn:
                assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'

        checkpointing_full_block = self.checkpointing == 'full-block' and self.training

        t_emb = self.time_embed(dist.timestep_embedding(index, self.C // 4, repeat_only=False))

        # Append the LQ control tokens after the (padded) target tokens, then add the step
        # embedding to every token of each sample.
        x_BLC = torch.cat([x_BLC, x_BLC_lq], dim=1)
        x_BLC = x_BLC + t_emb.unsqueeze(1)

        if self.num_block_chunks == 1:
            for i, b in enumerate(self.blocks):
                if i == 0 or not self.add_lvl_embeding_only_first_block:
                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
                if checkpointing_full_block:
                    x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, None, None, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
                else:
                    x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
        else:
            for i, chunk in enumerate(self.block_chunks):
                if i == 0 or not self.add_lvl_embeding_only_first_block:
                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
                x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)

    # Drop pad and LQ tokens before the output head.
    return self.get_logits(x_BLC[:, :l_end], cond_BD)
| |
def logits_to_img(self, logits_BlV_all, vae, scale_schedule, top_k=900, top_p=0.97, g_seed=1):
    """
    Turn per-token bit logits into an image via Gumbel-argmax and ``vae``.

    :param logits_BlV_all: logits of all scales concatenated along dim 1, (B, L, V). Each
        token's V logits are read as V/2 pairs of bit logits.
    :param vae: provides ``quantizer.lfq.indices_to_codes``, ``quantizer.z_interplote_up``
        and ``decode``
    :param scale_schedule: list of (t, h, w) patch grids; t must be 1
    :param top_k, top_p: not used on this (bit-label) path; kept for interface parity
    :param g_seed: if not None, ``self.rng`` is re-seeded with it
    :return: ``vae.decode`` of the codes of all scales, upsampled to the last scale and summed
    :raises NotImplementedError: if ``self.use_bit_label`` is false (previously a NameError)
    """
    if not self.use_bit_label:
        raise NotImplementedError('logits_to_img only supports bit labels (use_bit_label=True)')

    patch_nums_per_level = [pn[0] * pn[1] * pn[2] for pn in scale_schedule]
    B = logits_BlV_all.shape[0]

    if g_seed is not None:
        self.rng.manual_seed(g_seed)

    if self.apply_spatial_patchify:
        vae_scale_schedule = [(pt, 2 * ph, 2 * pw) for pt, ph, pw in scale_schedule]
    else:
        vae_scale_schedule = scale_schedule

    summed_codes = 0
    last_si = len(scale_schedule) - 1
    for si, logits_BlV in enumerate(torch.split(logits_BlV_all, patch_nums_per_level, dim=1)):
        pt, ph, pw = scale_schedule[si]
        tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
        one_hot = GumbelArgmax(logits_BlV.reshape(tmp_bs, -1, 2), 0.5)
        # Take channel 1 of the (..., 2) output (same as zero-masking channel 0 and summing).
        idx_Bld = one_hot[:, :, 1:].sum(dim=-1)
        idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)

        assert pt == 1
        idx_Bld = idx_Bld.reshape(B, ph, pw, -1)
        if self.apply_spatial_patchify:
            # Unfold the 2x2 spatial patch packed in the channel dim back into space.
            idx_Bld = F.pixel_shuffle(idx_Bld.permute(0, 3, 1, 2), 2).permute(0, 2, 3, 1)
        idx_Bld = idx_Bld.unsqueeze(1)

        codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label')
        if si != last_si:
            codes = F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
        summed_codes += codes

    return vae.decode(summed_codes.squeeze(-3))
| |
def logits_to_img_discrete(self, logits_BlV_all, vae, scale_schedule, top_k=900, top_p=0.97, g_seed=1):
    """
    Sample bit labels from per-token logits (top-k / top-p) and decode them with ``vae``.

    :param logits_BlV_all: logits of all scales concatenated along dim 1, (B, L, V). Each
        token's V logits are read as V/2 pairs of bit logits.
    :param vae: provides ``quantizer.lfq.indices_to_codes``, ``quantizer.z_interplote_up``
        and ``decode``
    :param scale_schedule: list of (t, h, w) patch grids; t must be 1
    :param top_k, top_p: sampling filters; falsy values fall back to ``self.top_k`` / ``self.top_p``
    :param g_seed: if not None, ``self.rng`` is re-seeded and used for sampling
    :return: ``vae.decode`` of the codes of all scales, upsampled to the last scale and summed
    :raises NotImplementedError: if ``self.use_bit_label`` is false (previously a NameError)
    """
    if not self.use_bit_label:
        raise NotImplementedError('logits_to_img_discrete only supports bit labels (use_bit_label=True)')

    patch_nums_per_level = [pn[0] * pn[1] * pn[2] for pn in scale_schedule]
    B = logits_BlV_all.shape[0]

    if g_seed is None:
        rng = None
    else:
        self.rng.manual_seed(g_seed)
        rng = self.rng

    if self.apply_spatial_patchify:
        vae_scale_schedule = [(pt, 2 * ph, 2 * pw) for pt, ph, pw in scale_schedule]
    else:
        vae_scale_schedule = scale_schedule

    summed_codes = 0
    last_si = len(scale_schedule) - 1
    for si, logits_BlV in enumerate(torch.split(logits_BlV_all, patch_nums_per_level, dim=1)):
        pt, ph, pw = scale_schedule[si]
        tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
        idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(
            logits_BlV.reshape(tmp_bs, -1, 2), rng=rng,
            top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1,
        )[:, :, 0]
        idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)

        assert pt == 1
        idx_Bld = idx_Bld.reshape(B, ph, pw, -1)
        if self.apply_spatial_patchify:
            # Unfold the 2x2 spatial patch packed in the channel dim back into space.
            idx_Bld = F.pixel_shuffle(idx_Bld.permute(0, 3, 1, 2), 2).permute(0, 2, 3, 1)
        idx_Bld = idx_Bld.unsqueeze(1)

        codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label')
        if si != last_si:
            codes = F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
        summed_codes += codes

    return vae.decode(summed_codes.squeeze(-3))
| |
@torch.no_grad()
def autoregressive_infer_cfg(
    self,
    vae=None,
    scale_schedule=None,
    label_B_or_BLT=None,
    B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
    g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
    returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
    cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
    vae_type=0, softmax_merge_topk=-1, ret_img=False,
    trunk_scale=1000,
    gt_leak=0, gt_ls_Bl=None,
    inference_mode=False,
    save_img_path=None,
    sampling_per_bits=1,
    lq_images=None,
):
    """
    Single-pass inference: predict every scale at once from noise and decode to an image.

    Procedure:
    - All tokens after the start token are initialised with Gaussian noise.
    - The LQ control tokens from ``lq_images`` are appended.
    - The transformer runs once with step index 0.
    - Bit labels are sampled from the logits and decoded with ``vae``.

    Only ``cfg_list[0]`` and ``tau_list[0]`` are used, but both lists must be at least as long
    as ``scale_schedule``. These arguments are accepted and ignored: ``force_gt_Bhw``,
    ``cfg_sc``, ``returns_vemb``, ``ratio_Bl1``, ``gumbel``, ``norm_cfg``, ``cfg_exp_k``,
    ``vae_type``, ``softmax_merge_topk``, ``ret_img``, ``trunk_scale``, ``gt_leak``,
    ``gt_ls_Bl``, ``save_img_path``, ``sampling_per_bits``.

    :param label_B_or_BLT: (kv_compact, lens, cu_seqlens_k, max_seqlen_k) text features
    :param B: number of images
    :param negative_label_B_or_BLT: optional negative-prompt features in the same format;
        without it, the learned ``cfg_uncond`` fills the unconditional half
    :param g_seed: seeds ``self.rng`` (initial noise and token sampling); None = unseeded
    :param cfg_list: guidance scales; any value != 1 doubles the batch
    :param tau_list: logits are divided by ``tau_list[0]``
    :param top_k, top_p: sampling filters (falsy values fall back to ``self.top_k``/``self.top_p``)
    :param cfg_insertion_layer: 0 = mix on logits; negative = mix hidden states after that
        layer (counted from the end); 1 is accepted but has no effect here
    :param inference_mode: selects whether KV caching is toggled via ``unregistered_blocks``
        or via ``block_chunks``
    :param lq_images: low-quality conditioning images (required)
    :return: ``(None, None, img)``, ``img`` uint8 of shape (B, H, W, C) with the channel
        order reversed
    """
    if g_seed is None:
        rng = None
    else:
        self.rng.manual_seed(g_seed)
        rng = self.rng
    assert len(cfg_list) >= len(scale_schedule)
    assert len(tau_list) >= len(scale_schedule)
    if lq_images is None:
        raise ValueError('autoregressive_infer_cfg requires `lq_images`')

    # Text conditioning, with an unconditional half appended when CFG is active.
    kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
    if any(np.array(cfg_list) != 1):
        bs = 2 * B
        if not negative_label_B_or_BLT:
            kv_compact_un = kv_compact.clone()
            total = 0
            for le in lens:
                kv_compact_un[total:total + le] = self.cfg_uncond[:le]
                total += le
            kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
            cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:] + cu_seqlens_k[-1]), dim=0)
        else:
            kv_compact_un, _, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
            kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
            cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:] + cu_seqlens_k[-1]), dim=0)
            max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
    else:
        bs = B

    kv_compact = self.text_norm(kv_compact)
    sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k))
    kv_compact = self.text_proj_for_ca(kv_compact)
    ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
    last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)

    with torch.amp.autocast('cuda', enabled=False):
        cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()

    # Where classifier-free guidance is mixed in.
    abs_cfg_insertion_layers = []
    add_cfg_on_logits = False
    leng = len(self.unregistered_blocks)
    for item in cfg_insertion_layer:
        if item == 0:
            add_cfg_on_logits = True
        elif item == 1:
            pass  # guidance on probabilities: accepted, but not applied by this method
        elif item < 0:
            assert leng + item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={leng}'
            abs_cfg_insertion_layers.append(leng + item)
        else:
            raise ValueError(f'cfg_insertion_layer: {item} is not valid')
    cfg = cfg_list[0]

    def set_kv_caching(enabled):
        # Toggle KV caching on each block's self-attention module.
        if inference_mode:
            for blk in self.unregistered_blocks:
                (blk.sa if isinstance(blk, CrossAttnBlock) else blk.attn).kv_caching(enabled)
        else:
            assert self.num_block_chunks > 1
            for block_chunk_ in self.block_chunks:
                for module in block_chunk_.module.module:
                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(enabled)

    set_kv_caching(True)
    try:
        # LQ control tokens, duplicated whenever the batch was doubled for CFG.
        x_BLC_lq = self.car_control_convs(lq_images)
        x_BLC_lq = x_BLC_lq.view(B, self.C, -1).transpose(1, 2).contiguous()
        if bs != B:
            x_BLC_lq = torch.cat([x_BLC_lq, x_BLC_lq], dim=0)

        patch_nums_per_level = [pn[0] * pn[1] * pn[2] for pn in scale_schedule]
        l_end = int(sum(patch_nums_per_level))
        noise_shape = (bs, l_end, last_stage.shape[-1])
        if rng is None:
            x_BLC = torch.randn(noise_shape).to(last_stage.device)
        else:
            # honour g_seed for the initial noise as well
            x_BLC = torch.randn(noise_shape, generator=rng, device=rng.device).to(last_stage.device)
        x_BLC[:, :1, :] = last_stage

        index = torch.zeros((bs,), device=x_BLC.device)
        t_emb = self.time_embed(dist.timestep_embedding(index, self.C // 4, repeat_only=False))
        x_BLC = torch.cat([x_BLC, x_BLC_lq], dim=1)
        x_BLC = x_BLC + t_emb.unsqueeze(1)

        layer_idx = 0
        for block_idx, chunk in enumerate(self.block_chunks):
            if block_idx == 0 or not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad=0)
            for m in chunk.module:
                x_BLC = m(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=0)
                if cfg != 1 and layer_idx in abs_cfg_insertion_layers:
                    x_BLC = cfg * x_BLC[:B] + (1 - cfg) * x_BLC[B:]
                    x_BLC = torch.cat((x_BLC, x_BLC), 0)
                layer_idx += 1

        if cfg != 1 and add_cfg_on_logits:
            logits_BlV = self.get_logits(x_BLC[:, :l_end], cond_BD).mul(1 / tau_list[0])
            logits_BlV = cfg * logits_BlV[:B] + (1 - cfg) * logits_BlV[B:]
        else:
            logits_BlV = self.get_logits(x_BLC[:B, :l_end], cond_BD[:B]).mul(1 / tau_list[0])

        img = self.logits_to_img_discrete(logits_BlV_all=logits_BlV,
                                          vae=vae,
                                          scale_schedule=scale_schedule,
                                          top_k=top_k,
                                          top_p=top_p,
                                          g_seed=g_seed)
    finally:
        # never leave the blocks in caching mode, even if inference failed
        set_kv_caching(False)

    # (img + 1) / 2 maps [-1, 1] to [0, 1]; clamp so out-of-range values saturate instead of
    # wrapping around in the uint8 cast.
    img = ((img + 1) / 2).clamp_(0, 1)
    img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
    return None, None, img
| |
@torch.no_grad()
def autoregressive_infer_cfg_multi_step(
    self,
    vae=None,
    scale_schedule=None,
    label_B_or_BLT=None,
    B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
    g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
    returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
    cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
    vae_type=0, softmax_merge_topk=-1, ret_img=False,
    trunk_scale=1000,
    gt_leak=0, gt_ls_Bl=None,
    inference_mode=False,
    save_img_path=None,
    sampling_per_bits=1,
    lq_images=None,
    index_list=None
):
    """
    Multi-pass inference: refine all scales with one transformer pass per entry of ``index_list``.

    At a step with value ``index``:
    - The tokens of scales ``0..index`` are kept (the start token, or the previous step's
      predictions); every later token is replaced by Gaussian noise.
    - The transformer runs with step embedding ``index``.
    - On all but the last step, bit labels are sampled from the logits, decoded through
      ``vae.quantizer`` and re-encoded with ``word_embed`` as the next step's inputs.

    The last step's logits are decoded to an image.

    Only ``cfg_list[0]`` and ``tau_list[0]`` are used, but both lists must be at least as long
    as ``scale_schedule``. Unlisted keyword arguments are accepted and ignored.

    :param label_B_or_BLT: (kv_compact, lens, cu_seqlens_k, max_seqlen_k) text features
    :param B: number of images
    :param negative_label_B_or_BLT: optional negative-prompt features; else ``cfg_uncond`` is used
    :param g_seed: seeds ``self.rng`` (noise and token sampling); None = unseeded
    :param cfg_list: guidance scales; any value != 1 doubles the batch
    :param cfg_insertion_layer: 0 = mix on logits; negative = mix hidden states after that
        layer (counted from the end); 1 is accepted but has no effect here
    :param lq_images: low-quality conditioning images (required)
    :param index_list: step values; must start with 0, each one indexes ``scale_schedule``
    :return: ``(None, None, img)``, ``img`` uint8 of shape (B, H, W, C) with the channel
        order reversed
    """
    if g_seed is None:
        rng = None
    else:
        self.rng.manual_seed(g_seed)
        rng = self.rng
    assert len(cfg_list) >= len(scale_schedule)
    assert len(tau_list) >= len(scale_schedule)
    assert index_list[0] == 0
    if lq_images is None:
        raise ValueError('autoregressive_infer_cfg_multi_step requires `lq_images`')
    if not self.use_bit_label:
        raise NotImplementedError('autoregressive_infer_cfg_multi_step only supports bit labels (use_bit_label=True)')

    if self.apply_spatial_patchify:
        vae_scale_schedule = [(pt, 2 * ph, 2 * pw) for pt, ph, pw in scale_schedule]
    else:
        vae_scale_schedule = scale_schedule

    # Text conditioning, with an unconditional half appended when CFG is active.
    kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
    if any(np.array(cfg_list) != 1):
        bs = 2 * B
        if not negative_label_B_or_BLT:
            kv_compact_un = kv_compact.clone()
            total = 0
            for le in lens:
                kv_compact_un[total:total + le] = self.cfg_uncond[:le]
                total += le
            kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
            cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:] + cu_seqlens_k[-1]), dim=0)
        else:
            kv_compact_un, _, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
            kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
            cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:] + cu_seqlens_k[-1]), dim=0)
            max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
    else:
        bs = B

    kv_compact = self.text_norm(kv_compact)
    sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k))
    kv_compact = self.text_proj_for_ca(kv_compact)
    ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
    last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)

    with torch.amp.autocast('cuda', enabled=False):
        cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()

    # Where classifier-free guidance is mixed in.
    abs_cfg_insertion_layers = []
    add_cfg_on_logits = False
    leng = len(self.unregistered_blocks)
    for item in cfg_insertion_layer:
        if item == 0:
            add_cfg_on_logits = True
        elif item == 1:
            pass  # guidance on probabilities: accepted, but not applied by this method
        elif item < 0:
            assert leng + item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={leng}'
            abs_cfg_insertion_layers.append(leng + item)
        else:
            raise ValueError(f'cfg_insertion_layer: {item} is not valid')
    cfg = cfg_list[0]

    def set_kv_caching(enabled):
        # Toggle KV caching on each block's self-attention module.
        if inference_mode:
            for blk in self.unregistered_blocks:
                (blk.sa if isinstance(blk, CrossAttnBlock) else blk.attn).kv_caching(enabled)
        else:
            assert self.num_block_chunks > 1
            for block_chunk_ in self.block_chunks:
                for module in block_chunk_.module.module:
                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(enabled)

    def noise_like(ref):
        # Gaussian noise shaped like `ref`, drawn from the seeded generator when one is given.
        if rng is None:
            return torch.randn_like(ref)
        return torch.randn(ref.shape, generator=rng, device=rng.device, dtype=ref.dtype).to(ref.device)

    set_kv_caching(True)
    try:
        # LQ control tokens, duplicated whenever the batch was doubled for CFG.
        x_BLC_lq = self.car_control_convs(lq_images)
        x_BLC_lq = x_BLC_lq.view(B, self.C, -1).transpose(1, 2).contiguous()
        if bs != B:
            x_BLC_lq = torch.cat([x_BLC_lq, x_BLC_lq], dim=0)

        patch_nums_per_level = [pn[0] * pn[1] * pn[2] for pn in scale_schedule]
        patch_nums_per_level_acc = np.cumsum(patch_nums_per_level).tolist()
        l_end = patch_nums_per_level_acc[-1]

        x_BLC = torch.zeros((bs, l_end, last_stage.shape[-1]), device=last_stage.device)
        x_BLC[:, :1, :] = last_stage

        last_step = len(index_list) - 1
        logits_final = None
        for step, index in enumerate(index_list):
            # Re-noise every token after the scales treated as known at this step.
            start = patch_nums_per_level_acc[index]
            x_BLC[:, start:, :] = noise_like(x_BLC[:, start:, :])

            x_in = torch.cat([x_BLC, x_BLC_lq], dim=1)
            index_tensor = torch.full((bs,), index, device=x_in.device)
            t_emb = self.time_embed(dist.timestep_embedding(index_tensor, self.C // 4, repeat_only=False))
            # One embedding row per sample: unsqueeze so it broadcasts over that sample's tokens
            # (a plain `+ t_emb` only broadcasts correctly when bs == 1).
            x_in = x_in + t_emb.unsqueeze(1)

            layer_idx = 0
            for block_idx, chunk in enumerate(self.block_chunks):
                if block_idx == 0 or not self.add_lvl_embeding_only_first_block:
                    x_in = self.add_lvl_embeding_for_x_BLC(x_in, scale_schedule, need_to_pad=0)
                for m in chunk.module:
                    x_in = m(x=x_in, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=0)
                    if cfg != 1 and layer_idx in abs_cfg_insertion_layers:
                        x_in = cfg * x_in[:B] + (1 - cfg) * x_in[B:]
                        x_in = torch.cat((x_in, x_in), 0)
                    layer_idx += 1

            if cfg != 1 and add_cfg_on_logits:
                logits_BlV = self.get_logits(x_in[:, :l_end], cond_BD).mul(1 / tau_list[0])
                logits_BlV = cfg * logits_BlV[:B] + (1 - cfg) * logits_BlV[B:]
            else:
                logits_BlV = self.get_logits(x_in[:B, :l_end], cond_BD[:B]).mul(1 / tau_list[0])

            # Position-based check: a repeated final value in index_list must not stop early.
            if step == last_step:
                logits_final = logits_BlV
                break

            # Sample bit labels and turn them into the next step's input tokens.
            tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
            idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(
                logits_BlV.reshape(tmp_bs, -1, 2), rng=rng,
                top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1,
            )[:, :, 0]
            idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)

            cum_var_input = 0
            next_inputs = []
            for si, bit_indices in enumerate(torch.split(idx_Bld, patch_nums_per_level, dim=1)):
                pt, ph, pw = scale_schedule[si]
                bit_indices = bit_indices.reshape(B, pt, ph, pw, -1)
                if self.apply_spatial_patchify:
                    # Each token packs a 2x2 spatial patch in its channel dim; unfold it so the
                    # grid matches vae_scale_schedule (the old reshape had the wrong numel).
                    bit_indices = bit_indices.reshape(B * pt, ph, pw, -1).permute(0, 3, 1, 2)
                    bit_indices = F.pixel_shuffle(bit_indices, 2).permute(0, 2, 3, 1)
                    bit_indices = bit_indices.reshape(B, pt, 2 * ph, 2 * pw, -1)

                quantized = vae.quantizer.lfq.indices_to_codes(bit_indices, label_type='bit_label')
                cum_var_input += F.interpolate(quantized, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)

                if si < len(vae_scale_schedule) - 1:
                    this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si + 1], mode=vae.quantizer.z_interplote_up)
                    if self.apply_spatial_patchify:
                        this_scale_input = F.pixel_unshuffle(this_scale_input.squeeze(-3), 2)
                    next_inputs.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0, 2, 1))

            x_BLC_wo_prefix = torch.cat(next_inputs, 1)
            next_tokens = self.word_embed(self.norm0_ve(x_BLC_wo_prefix))
            if bs != B:
                # feed both CFG halves the same predicted tokens (last_stage has bs rows)
                next_tokens = torch.cat((next_tokens, next_tokens), dim=0)
            x_BLC = torch.cat((last_stage, next_tokens), dim=1)
            assert x_BLC.shape[1] == l_end

        img = self.logits_to_img_discrete(logits_BlV_all=logits_final,
                                          vae=vae,
                                          scale_schedule=scale_schedule,
                                          top_k=top_k,
                                          top_p=top_p,
                                          g_seed=g_seed)
    finally:
        # never leave the blocks in caching mode, even if inference failed
        set_kv_caching(False)

    # (img + 1) / 2 maps [-1, 1] to [0, 1]; clamp so out-of-range values saturate instead of
    # wrapping around in the uint8 cast.
    img = ((img + 1) / 2).clamp_(0, 1)
    img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
    return None, None, img
| |
@for_visualize
def vis_key_params(self, ep):
    """Visualization hook, wrapped by ``for_visualize``; intentionally does nothing and returns None."""
| |
def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False):
    """
    Load a checkpoint, tolerating a different ``cfg_uncond`` length and ignoring stored buffers.

    Adjustments made before loading:
    - Every key containing ``cfg_uncond`` is resized to ``self.cfg_uncond``. It is truncated
      when longer, or completed with this model's current trailing rows when shorter.
    - Entries named ``lvl_1L``, ``attn_bias_for_masking``, ``Infinity_visible_kvlen`` or
      ``Infinity_invisible_qlen`` are dropped. They are replaced by this model's own attribute
      when it has one, so those values are never taken from the checkpoint.

    The caller's ``state_dict`` is not modified: the edits are made on a shallow copy.

    :return: the result of ``nn.Module.load_state_dict``
    """
    # Work on a shallow copy so the caller's mapping is untouched; keep `_metadata`, which
    # nn.Module.load_state_dict reads.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    for k in list(state_dict):
        if 'cfg_uncond' in k:
            old, new = state_dict[k], self.cfg_uncond.data
            min_tlen = min(old.shape[0], new.shape[0])
            if min_tlen == old.shape[0]:
                # checkpoint is not longer: keep its rows, fill the rest from the current value
                state_dict[k] = torch.cat((old.to(device=new.device, dtype=new.dtype), new[min_tlen:]))
            else:
                # checkpoint is longer: truncate
                state_dict[k] = old[:min_tlen]

    for buf_name in ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen'):
        state_dict.pop(buf_name, None)
        if hasattr(self, buf_name):
            state_dict[buf_name] = getattr(self, buf_name)

    return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
| |
def special_init(
    self,
    aln_init: float,
    aln_gamma_init: float,
    scale_head: float,
    scale_proj: int,
):
    """
    Rescale freshly initialised weights of the head, the blocks' output projections and the
    adaptive-norm modulation layers (all in place).

    :param aln_init: multiplier for the last linear of ``self.head_nm`` (when it is an
        ``AdaLNBeforeHead``). Also for rows ``[2*C:]`` of each block's ``ada_lin[-1]`` weight,
        or chunks ``[:, :, 2:]`` of ``ada_gss``.
    :param aln_gamma_init: multiplier for rows ``[:2*C]`` of ``ada_lin[-1]``, or chunks
        ``[:, :, :2]`` of ``ada_gss``
    :param scale_head: multiplier for the weight of the head's (last) linear layer, whose bias
        is zeroed; skipped when negative
    :param scale_proj: when 1, attention and FFN output projections are scaled by 1/sqrt(2*depth)
    """
    if isinstance(self.head_nm, AdaLNBeforeHead):
        head_nm_lin = self.head_nm.ada_lin[-1]
        head_nm_lin.weight.data.mul_(aln_init)
        if getattr(head_nm_lin, 'bias', None) is not None:
            head_nm_lin.bias.data.zero_()

    if scale_head >= 0:
        if isinstance(self.head, nn.Linear):
            head_lin = self.head
        elif isinstance(self.head, nn.Sequential):
            head_lin = self.head[-1]
        else:
            head_lin = None
        if head_lin is not None:
            head_lin.weight.data.mul_(scale_head)
            # bias-free heads are allowed (same guard as for head_nm above)
            if getattr(head_lin, 'bias', None) is not None:
                head_lin.bias.data.zero_()

    depth = len(self.unregistered_blocks)
    for block_idx, sab in enumerate(self.unregistered_blocks):
        sab: Union[SelfAttnBlock, CrossAttnBlock]
        # NOTE(review): only scale_proj == 1 rescales below, so the 2*(1 + block_idx) alternative
        # of `scale` is currently unused. Indentation was reconstructed from a mangled copy, so
        # confirm the ffn.fc2 line belongs inside the `if`.
        scale = 1 / math.sqrt(2 * depth if scale_proj == 1 else 2 * (1 + block_idx))
        if scale_proj == 1:
            if self.t2i:
                sab.sa.proj.weight.data.mul_(scale)
                sab.ca.proj.weight.data.mul_(scale)
            else:
                sab.attn.proj.weight.data.mul_(scale)
            sab.ffn.fc2.weight.data.mul_(scale)

        if hasattr(sab, 'ada_lin'):
            lin = sab.ada_lin[-1]
            lin.weight.data[:2 * self.C].mul_(aln_gamma_init)
            lin.weight.data[2 * self.C:].mul_(aln_init)
            if getattr(lin, 'bias', None) is not None:
                lin.bias.data.zero_()
        elif hasattr(sab, 'ada_gss'):
            sab.ada_gss.data[:, :, :2, :].mul_(aln_gamma_init)
            sab.ada_gss.data[:, :, 2:, :].mul_(aln_init)
| |
def extra_repr(self):
    """Extra text shown in this module's ``repr``: its drop-path rate."""
    return 'drop_path_rate={}'.format(self.drop_path_rate)
| |
def get_layer_id_and_scale_exp(self, para_name: str):
    """Layer-wise scaling lookup; not supported by this model, so it always raises ``NotImplementedError``."""
    raise NotImplementedError()
|
|
| |
class ControlConditionEmbedding(nn.Module):
    """
    Convolutional encoder mapping a conditioning image to an embedding feature map.

    Layout:
    - ``conv_in`` (3x3).
    - For every consecutive pair in ``block_out_channels``: a 3x3 conv keeping the width,
      then a 3x3 stride-2 conv changing it.
    - ``conv_out`` (3x3).

    Every conv except ``conv_out`` is followed by SiLU. Each stride-2 conv halves the spatial
    size.

    :param conditioning_embedding_channels: output channels of ``conv_out``
    :param conditioning_channels: input image channels
    :param block_out_channels: channel widths of the successive stages
    :param conv_out_stride: stride of ``conv_out``. The default 1 keeps the last stage's
        resolution; 2 halves it once more (the layout of ``ControlConditionEmbedding_patch_size_32``).
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
        conv_out_stride: int = 1,
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])
        for channel_in, channel_out in zip(block_out_channels[:-1], block_out_channels[1:]):
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = nn.Conv2d(
            block_out_channels[-1], conditioning_embedding_channels,
            kernel_size=3, padding=1, stride=conv_out_stride,
        )

    def forward(self, conditioning):
        """Encode ``conditioning`` (N, conditioning_channels, H, W) into an embedding map."""
        embedding = F.silu(self.conv_in(conditioning))
        for block in self.blocks:
            embedding = F.silu(block(embedding))
        return self.conv_out(embedding)
|
|
class ControlConditionEmbedding_patch_size_32(nn.Module):
    """
    Conditioning-image encoder whose output conv also has stride 2, giving one extra 2x
    downsampling on top of the stride-2 stage convs.

    Layout:
    - ``conv_in`` (3x3).
    - For every consecutive pair of widths: a width-preserving 3x3 conv, then a stride-2 3x3
      conv to the next width.
    - ``conv_out`` (3x3, stride 2).

    SiLU follows every conv except ``conv_out``.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
    ):
        super().__init__()
        widths = block_out_channels

        self.conv_in = nn.Conv2d(conditioning_channels, widths[0], kernel_size=3, padding=1)

        stage_convs = []
        for pos in range(1, len(widths)):
            prev_w, next_w = widths[pos - 1], widths[pos]
            stage_convs.append(nn.Conv2d(prev_w, prev_w, kernel_size=3, padding=1))
            stage_convs.append(nn.Conv2d(prev_w, next_w, kernel_size=3, padding=1, stride=2))
        self.blocks = nn.ModuleList(stage_convs)

        self.conv_out = nn.Conv2d(widths[-1], conditioning_embedding_channels, kernel_size=3, padding=1, stride=2)

    def forward(self, conditioning):
        """Encode ``conditioning`` (N, conditioning_channels, H, W) into an embedding map."""
        hidden = F.silu(self.conv_in(conditioning))
        for conv in self.blocks:
            hidden = F.silu(conv(hidden))
        return self.conv_out(hidden)
|
|
class FP32_Layernorm(nn.LayerNorm):
    """LayerNorm computed in float32 regardless of input dtype; the result is cast back to the input dtype."""

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        # weight/bias are None when the layer has no affine parameters (elementwise_affine=False).
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        return F.layer_norm(inputs.float(), self.normalized_shape, weight, bias, self.eps).to(origin_dtype)
|
|
| class CInfinity(Infinity): |
def __init__(
    self, vae_local,
    text_channels=0, text_maxlen=0,
    selecting_idx=None,
    embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4.,
    drop_rate=0., drop_path_rate=0.,
    norm_eps=1e-6, rms_norm=False,
    shared_aln=False, head_aln=True,
    cond_drop_rate=0.1,
    rand_uncond=False,
    cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False,
    raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
    head_depth=1,
    top_p=0.0, top_k=0.0,
    customized_flash_attn=False, fused_mlp=False, fused_norm=False,
    block_chunks=1,
    checkpointing=None,
    pad_to_multiplier=0,
    use_flex_attn=False,
    batch_size=2,
    add_lvl_embeding_only_first_block=1,
    use_bit_label=1,
    rope2d_each_sa_layer=0,
    rope2d_normalized_by_hw=0,
    pn=None,
    train_h_div_w_list=None,
    video_frames=1,
    always_training_scales=20,
    apply_spatial_patchify = 0,
    inference_mode=False,

):
    """
    Build the base model, then add the ``car_*`` modules.

    Every argument is forwarded positionally to the parent ``__init__``. The extra modules are:
    - ``car_var_conv``: a 3x3 C -> C conv.
    - ``depth // 2`` extra transformer blocks (``CrossAttnBlock`` if ``self.t2i`` else
      ``SelfAttnBlock``) with the same hyper-parameters, exposed as ``car_blocks`` or, when
      blocks are chunked, ``car_block_chunks``.
    - ``depth // 2`` pairs of (``FP32_Layernorm`` over 2*C, ``Linear`` 2*C -> C) in
      ``car_skip_norm`` / ``car_skip_linear``.

    Module creation order is significant: it fixes how the global RNG is consumed during
    parameter initialisation.
    """
    super(CInfinity, self).__init__(vae_local,
                                    text_channels, text_maxlen,
                                    selecting_idx,
                                    embed_dim, depth, num_heads, mlp_ratio,
                                    drop_rate, drop_path_rate,
                                    norm_eps, rms_norm,
                                    shared_aln, head_aln,
                                    cond_drop_rate,
                                    rand_uncond,
                                    cross_attn_layer_scale, nm0, tau, cos_attn, swiglu,
                                    raw_scale_schedule,
                                    head_depth,
                                    top_p, top_k,
                                    customized_flash_attn, fused_mlp, fused_norm,
                                    block_chunks,
                                    checkpointing,
                                    pad_to_multiplier,
                                    use_flex_attn,
                                    batch_size,
                                    add_lvl_embeding_only_first_block,
                                    use_bit_label,
                                    rope2d_each_sa_layer,
                                    rope2d_normalized_by_hw,
                                    pn,
                                    train_h_div_w_list,
                                    video_frames,
                                    always_training_scales,
                                    apply_spatial_patchify,
                                    inference_mode,)

    kernel = 3
    self.car_var_conv = nn.Conv2d(self.C, self.C, kernel_size=kernel, padding=(kernel - 1) // 2)
    norm_layer = partial(nn.LayerNorm, eps=norm_eps)
    self.drop_path_rate = drop_path_rate
    dpr = torch.linspace(0, drop_path_rate, depth).tolist()

    fused_norm_func = None
    if fused_norm:
        fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm
        if fused_norm_func is not None:
            # Scratch tensors that are discarded right away; they still advance the global RNG,
            # so they are kept to preserve the parameter initialisation that follows.
            probe = torch.randn(2, 1, self.C).requires_grad_(True)
            probe_scale = torch.randn(2, 1, self.C).mul_(0.01).requires_grad_(True)
            probe_shift = torch.randn(2, 1, self.C).mul_(0.01).requires_grad_(True)
            del probe, probe_scale, probe_shift

    block_cls = CrossAttnBlock if self.t2i else SelfAttnBlock
    self.car_unregistered_blocks = [
        block_cls(
            embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer,
            num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[layer], tau=tau, cos_attn=cos_attn,
            swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func,
            checkpointing_sa_only=self.checkpointing == 'self-attn',
            use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
        )
        for layer in range(depth // 2)
    ]

    if self.num_block_chunks == 1:
        self.car_blocks = nn.ModuleList(self.car_unregistered_blocks)
    else:
        self.car_block_chunks = nn.ModuleList(
            MultipleLayers(self.car_unregistered_blocks, self.num_blocks_in_a_chunk, chunk_id * self.num_blocks_in_a_chunk)
            for chunk_id in range(self.num_block_chunks // 2)
        )

    # LayerNorm init draws no random numbers, so building the norms and the linears in two
    # separate passes consumes the RNG exactly like interleaving them.
    self.car_skip_norm = nn.ModuleList(
        [FP32_Layernorm(2 * self.C, elementwise_affine=True, eps=1e-6) for _ in range(depth // 2)]
    )
    self.car_skip_linear = nn.ModuleList(
        [nn.Linear(2 * self.C, self.C) for _ in range(depth // 2)]
    )
|
|
|
|
def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
    cfg_infer=False, x_BLC_w_prefix_lq=None,
    **kwargs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Teacher-forced forward pass with the CAR control branch.

    The low-quality (LQ) condition tokens are fused, scale by scale, with a
    3x3 convolution (``car_var_conv``) of the main token map and run through
    the control blocks (``car_blocks`` / ``car_block_chunks``). Their outputs
    are collected in a list and popped (last-in, first-out) into the second
    half of the main transformer through ``car_skip_norm`` +
    ``car_skip_linear``.

    Args:
        label_B_or_BLT: text condition, unpacked as
            ``(kv_compact, lens, cu_seqlens_k, max_seqlen_k)``.
        x_BLC_wo_prefix: input token features without the SOS prefix.
        scale_schedule: per-scale ``(t, h, w)`` patch counts.
        cfg_infer: if True, delegate to ``self.autoregressive_infer_cfg``.
        x_BLC_w_prefix_lq: LQ condition features. After embedding, it is split
            with the same per-scale sizes as the SOS-prefixed main sequence, so
            it must have one token per position of that sequence.

    Returns:
        ``self.get_logits`` of the first ``l_end`` (unpadded) positions.

    Raises:
        ValueError: if ``x_BLC_w_prefix_lq`` is not given (and ``cfg_infer`` is False).
    """
    if cfg_infer:
        return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)
    if x_BLC_w_prefix_lq is None:
        raise ValueError('forward() requires x_BLC_w_prefix_lq (the low-quality condition tokens)')

    x_BLC_wo_prefix = x_BLC_wo_prefix.float()
    x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.float()
    B = x_BLC_wo_prefix.shape[0]

    with torch.amp.autocast('cuda', enabled=False):
        kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT

        # Condition dropout: replace a sample's text features with cfg_uncond.
        total = 0
        for le in lens:
            if random.random() < self.cond_drop_rate:
                kv_compact[total:total+le] = self.cfg_uncond[:le]
            total += le
        # Zero-valued term that ties cfg_uncond into the graph even when nothing was dropped.
        must_on_graph = self.cfg_uncond[0, 0] * 0
        kv_compact = self.text_norm(kv_compact).contiguous()
        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous()
        kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
        kv_compact[0, 0] += must_on_graph
        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k

        cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous()

        sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
        x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)
        x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq))

        # Fuse each scale of the LQ condition with a conv of the main token map.
        # The spatial grid comes from the schedule's own (h, w), so non-square
        # scales reshape correctly (previously int(sqrt(t*h*w)) was used, which
        # only works for square, single-frame scales).
        patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]
        x_BLC_lq_list = torch.split(x_BLC_lq, patch_nums_per_level, dim=1)
        x_BLC_list = torch.split(x_BLC, patch_nums_per_level, dim=1)
        CVae = x_BLC.shape[-1]
        fused_lq_list = []
        for pn, var_x, lq_x in zip(scale_schedule, x_BLC_list, x_BLC_lq_list):
            ph, pw = pn[1], pn[2]
            var_x = self.car_var_conv(var_x.transpose(1, 2).contiguous().reshape(B, CVae, ph, pw))
            car_x = var_x + lq_x.transpose(1, 2).contiguous().reshape(B, CVae, ph, pw)
            fused_lq_list.append(car_x.view(B, CVae, -1).transpose(1, 2).contiguous())
        x_BLC_lq = torch.cat(fused_lq_list, dim=1)

        # Pad the sequence length to a multiple of pad_to_multiplier and build the attention mask.
        l_end = x_BLC.shape[1]
        need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end

        if self.customized_flash_attn:
            Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end]
            Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end]
            attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen)
        elif self.use_flex_attn:
            if need_to_pad:
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
                x_BLC_lq = F.pad(x_BLC_lq, (0, 0, 0, need_to_pad))
            assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'
            attn_bias_or_two_vector = None
        else:
            # Block-causal mask: a token attends to every token of its own and earlier scales.
            d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1)
            dT = d.transpose(1, 2)
            attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end)
            attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous()
            if need_to_pad:
                attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf)
                # Padded query rows see position 0 so they are not fully masked.
                attn_bias[0, 0, l_end:, 0] = 0
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
                x_BLC_lq = F.pad(x_BLC_lq, (0, 0, 0, need_to_pad))
            attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device)

    if self.use_flex_attn:
        attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)]
    else:
        attn_fn = None

    checkpointing_full_block = self.checkpointing == 'full-block' and self.training

    # Control branch: one residual per control block (or chunk).
    control_residual_f = []
    if self.num_block_chunks == 1:
        for i, b in enumerate(self.car_blocks):
            if i == 0 or not self.add_lvl_embeding_only_first_block:
                x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad)
            if checkpointing_full_block:
                x_BLC_lq = torch.utils.checkpoint.checkpoint(b, x_BLC_lq, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
            else:
                x_BLC_lq = b(x=x_BLC_lq, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
            control_residual_f.append(x_BLC_lq)
    else:
        for i, chunk in enumerate(self.car_block_chunks):
            if i == 0 or not self.add_lvl_embeding_only_first_block:
                x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad)
            x_BLC_lq = chunk(x=x_BLC_lq, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)
            control_residual_f.append(x_BLC_lq)

    # Main branch: the second half consumes the control residuals in reverse order.
    if self.num_block_chunks == 1:
        skip_start = len(self.blocks) // 2
        for i, b in enumerate(self.blocks):
            if i == 0 or not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if i >= skip_start:
                con_f = control_residual_f.pop()
                cat = self.car_skip_norm[i - skip_start](torch.cat([x_BLC, con_f], dim=-1))
                x_BLC = self.car_skip_linear[i - skip_start](cat)
            if checkpointing_full_block:
                x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
            else:
                x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
    else:
        skip_start = len(self.block_chunks) // 2
        for i, chunk in enumerate(self.block_chunks):
            if i == 0 or not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if i >= skip_start:
                con_f = control_residual_f.pop()
                cat = self.car_skip_norm[i - skip_start](torch.cat([x_BLC, con_f], dim=-1))
                x_BLC = self.car_skip_linear[i - skip_start](cat)
            x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)

    return self.get_logits(x_BLC[:, :l_end], cond_BD)
| |
@torch.no_grad()
def car_inference(
    self,
    vae=None,
    scale_schedule=None,
    label_B_or_BLT=None,
    B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
    g_seed=None, cfg_list=None, tau_list=None, cfg_sc=3, top_k=0, top_p=0.0,
    returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
    cfg_exp_k: float=0.0, cfg_insertion_layer=None,
    vae_type=0, softmax_merge_topk=-1, ret_img=False,
    trunk_scale=1000,
    gt_leak=0, gt_ls_Bl=None,
    inference_mode=False,
    save_img_path=None,
    sampling_per_bits=1,
    x_BLC_w_prefix_lq=None
):
    """Sample scale by scale with the CAR control branch and KV caching.

    For every scale ``si``:
    1. The control branch encodes that scale's LQ tokens, fused with
       ``car_var_conv`` of the current ``last_stage``.
    2. Its per-chunk outputs are popped (last-in, first-out) into the second
       half of the main transformer via ``car_skip_norm`` / ``car_skip_linear``.

    Args (selection):
        vae: VAE used to map sampled indices back to codes / images.
        scale_schedule: per-scale ``(t, h, w)`` patch counts.
        label_B_or_BLT: ``(kv_compact, lens, cu_seqlens_k, max_seqlen_k)``.
        B: number of samples. The working batch is ``2*B`` when any
            ``cfg_list`` entry differs from 1 (conditional + unconditional halves).
        negative_label_B_or_BLT: optional negative condition with the same
            layout. When falsy, ``self.cfg_uncond`` is used.
        cfg_list, tau_list: per-scale CFG weight and logit temperature. Each
            needs at least ``len(scale_schedule)`` entries (empty by default).
        cfg_insertion_layer: ``0`` applies CFG on logits, ``1`` sets
            ``add_cfg_on_probs`` (not read in this method), and negative values
            select main-branch layers counted from the end. Default ``[-5]``.
        inference_mode: KV caching is toggled through ``unregistered_blocks``
            (True) or through ``block_chunks[*].module.module`` (False).
        x_BLC_w_prefix_lq: LQ condition features; required.
        ``force_gt_Bhw``, ``cfg_sc``, ``ratio_Bl1``, ``gumbel``, ``norm_cfg``,
        ``cfg_exp_k``, ``softmax_merge_topk``, ``save_img_path`` and
        ``sampling_per_bits`` are accepted but not used here.

    Returns:
        ``(ret, idx_Bl_list, img)``. ``img`` is ``[]`` unless ``ret_img``;
        otherwise it is a uint8 channel-last tensor with the channel order
        reversed.

    Raises:
        ValueError: if ``x_BLC_w_prefix_lq`` is missing or a
            ``cfg_insertion_layer`` entry is positive and not 1.
    """
    cfg_list = [] if cfg_list is None else cfg_list
    tau_list = [] if tau_list is None else tau_list
    cfg_insertion_layer = [-5] if cfg_insertion_layer is None else cfg_insertion_layer
    if x_BLC_w_prefix_lq is None:
        raise ValueError('car_inference() requires x_BLC_w_prefix_lq (the low-quality condition tokens)')

    if g_seed is None:
        rng = None
    else:
        self.rng.manual_seed(g_seed)
        rng = self.rng
    assert len(cfg_list) >= len(scale_schedule)
    assert len(tau_list) >= len(scale_schedule)

    if self.apply_spatial_patchify:
        vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
    else:
        vae_scale_schedule = scale_schedule

    # Text condition; append the unconditional half when CFG is active.
    kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
    if any(np.array(cfg_list) != 1):
        bs = 2*B
        if not negative_label_B_or_BLT:
            kv_compact_un = kv_compact.clone()
            total = 0
            for le in lens:
                kv_compact_un[total:total+le] = (self.cfg_uncond)[:le]
                total += le
            kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
            cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0)
        else:
            kv_compact_un, _, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
            kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
            cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0)
            max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
    else:
        bs = B

    kv_compact = self.text_norm(kv_compact)
    sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k))
    kv_compact = self.text_proj_for_ca(kv_compact)
    ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
    last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)

    x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.expand(bs, -1, -1)
    x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq))

    with torch.amp.autocast('cuda', enabled=False):
        cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()
    accu_BChw, cur_L, ret = None, 0, []
    idx_Bl_list, idx_Bld_list = [], []

    # Parse CFG insertion points before touching KV caching, so an invalid
    # value cannot leave the attention layers in caching mode.
    abs_cfg_insertion_layers = []
    add_cfg_on_logits, add_cfg_on_probs = False, False
    leng = len(self.unregistered_blocks)
    for item in cfg_insertion_layer:
        if item == 0:
            add_cfg_on_logits = True
        elif item == 1:
            add_cfg_on_probs = True
        elif item < 0:
            assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={leng}'
            abs_cfg_insertion_layers.append(leng+item)
        else:
            raise ValueError(f'cfg_insertion_layer: {item} is not valid')

    def _set_kv_caching(enabled):
        # Toggle KV caching of every attention layer: main branch first, then control branch.
        if inference_mode:
            blocks = list(self.unregistered_blocks) + list(self.car_unregistered_blocks)
        else:
            assert self.num_block_chunks > 1
            # NOTE(review): `.module.module` presumably unwraps a wrapper around MultipleLayers; confirm.
            blocks = [m for chunk_ in self.block_chunks for m in chunk_.module.module]
            blocks += [m for chunk_ in self.car_block_chunks for m in chunk_.module.module]
        for blk in blocks:
            (blk.sa if isinstance(blk, CrossAttnBlock) else blk.attn).kv_caching(enabled)

    num_stages_minus_1 = len(scale_schedule)-1
    summed_codes = 0
    patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]
    x_BLC_lq_list = torch.split(x_BLC_lq, patch_nums_per_level, dim=1)

    _set_kv_caching(True)
    try:
        for si, pn in enumerate(scale_schedule):
            cfg = cfg_list[si]
            if si >= trunk_scale:
                break
            cur_L += np.array(pn).prod()

            need_to_pad = 0
            attn_fn = None
            if self.use_flex_attn:
                attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None)

            # Control branch for this scale.
            C_last = last_stage.shape[-1]
            control_x = x_BLC_lq_list[si].transpose(1, 2).contiguous().reshape(bs, C_last, pn[1], pn[2])
            var_x = last_stage.transpose(1, 2).contiguous().reshape(bs, C_last, pn[1], pn[2])
            control_x = self.car_var_conv(var_x) + control_x
            control_x = control_x.view(bs, C_last, -1).transpose(1, 2)

            control_residual_f = []
            for i, chunk in enumerate(self.car_block_chunks):
                if i == 0 or not self.add_lvl_embeding_only_first_block:
                    control_x = self.add_lvl_embeding(control_x, si, scale_schedule, need_to_pad=need_to_pad)
                # Call blocks one by one so each receives scale_ind=si, exactly as the
                # main branch does below; calling the chunk as a whole passes no scale_ind.
                for m in chunk.module:
                    control_x = m(x=control_x, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si)
                control_residual_f.append(control_x)

            # Main branch; the second half consumes control residuals in reverse order.
            layer_idx = 0
            skip_start = len(self.block_chunks) // 2
            for block_idx, b in enumerate(self.block_chunks):
                if block_idx == 0 or not self.add_lvl_embeding_only_first_block:
                    last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)
                if block_idx >= skip_start:
                    con_f = control_residual_f.pop()
                    cat = self.car_skip_norm[block_idx - skip_start](torch.cat([last_stage, con_f], dim=-1))
                    last_stage = self.car_skip_linear[block_idx - skip_start](cat)
                for m in b.module:
                    last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si)
                    if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers):
                        # Feature-level CFG: mix halves, then duplicate to keep batch size bs.
                        last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:]
                        last_stage = torch.cat((last_stage, last_stage), 0)
                    layer_idx += 1

            if (cfg != 1) and add_cfg_on_logits:
                logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si])
                logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:]
            else:
                logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si])

            if self.use_bit_label:
                tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
                logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
                idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
                idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)
            else:
                idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]

            if vae_type != 0:
                assert returns_vemb
                if si < gt_leak:
                    idx_Bld = gt_ls_Bl[si]
                else:
                    assert pn[0] == 1
                    idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1)
                    if self.apply_spatial_patchify:
                        idx_Bld = idx_Bld.permute(0,3,1,2)
                        idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2)
                        idx_Bld = idx_Bld.permute(0,2,3,1)
                    idx_Bld = idx_Bld.unsqueeze(1)

                idx_Bld_list.append(idx_Bld)
                codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label')
                if si != num_stages_minus_1:
                    summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
                    last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_down)
                    last_stage = last_stage.squeeze(-3)
                    if self.apply_spatial_patchify:
                        last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2)
                    last_stage = last_stage.reshape(*last_stage.shape[:2], -1)
                    last_stage = torch.permute(last_stage, [0,2,1])
                else:
                    summed_codes += codes
            else:
                if si < gt_leak:
                    idx_Bl = gt_ls_Bl[si]
                h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float()
                h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2])
                ret.append(h_BChw if returns_vemb != 0 else idx_Bl)
                idx_Bl_list.append(idx_Bl)
                if si != num_stages_minus_1:
                    accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule)

            if si != num_stages_minus_1:
                last_stage = self.word_embed(self.norm0_ve(last_stage))
                last_stage = last_stage.repeat(bs//B, 1, 1)
    finally:
        # Always leave the model with KV caching disabled, even if sampling failed.
        _set_kv_caching(False)

    if not ret_img:
        return ret, idx_Bl_list, []

    if vae_type != 0:
        img = vae.decode(summed_codes.squeeze(-3))
    else:
        img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True)

    img = (img + 1) / 2
    img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
    return ret, idx_Bl_list, img
|
|
| class CInfinity2(Infinity): |
def __init__(
    self, vae_local,
    text_channels=0, text_maxlen=0,
    selecting_idx=None,
    embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4.,
    drop_rate=0., drop_path_rate=0.,
    norm_eps=1e-6, rms_norm=False,
    shared_aln=False, head_aln=True,
    cond_drop_rate=0.1,
    rand_uncond=False,
    cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False,
    raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
    head_depth=1,
    top_p=0.0, top_k=0.0,
    customized_flash_attn=False, fused_mlp=False, fused_norm=False,
    block_chunks=1,
    checkpointing=None,
    pad_to_multiplier=0,
    use_flex_attn=False,
    batch_size=2,
    add_lvl_embeding_only_first_block=1,
    use_bit_label=1,
    rope2d_each_sa_layer=0,
    rope2d_normalized_by_hw=0,
    pn=None,
    train_h_div_w_list=None,
    video_frames=1,
    always_training_scales=20,
    apply_spatial_patchify = 0,
    inference_mode=False,
):
    """Build the base model plus the CAR control branch.

    Added on top of the base class:
      * ``car_control_convs``: ``ControlConditionEmbedding`` applied to the
        condition tokens in ``forward``.
      * ``car_var_conv``: 3x3, channel-preserving conv over the main token map.
      * ``depth // 2`` control blocks, exposed as ``car_blocks`` (one chunk)
        or grouped into ``num_block_chunks // 2`` ``car_block_chunks``.
      * ``depth // 2`` pairs of ``car_skip_norm`` (FP32 LayerNorm over
        ``2*C``) and ``car_skip_linear`` (``2*C -> C``) for the skip fusion.

    All arguments are forwarded positionally to the base class. ``self.C``,
    ``self.D``, ``self.t2i``, ``self.num_block_chunks``,
    ``self.num_blocks_in_a_chunk``, ``self.customized_flash_attn`` and
    ``self.checkpointing`` are read from the state the base class sets up.
    """
    # Zero-argument super(): the previous `super(CInfinity, self)` raises
    # TypeError here, because CInfinity2 is not a subclass of CInfinity.
    super().__init__(vae_local,
        text_channels, text_maxlen,
        selecting_idx,
        embed_dim, depth, num_heads, mlp_ratio,
        drop_rate, drop_path_rate,
        norm_eps, rms_norm,
        shared_aln, head_aln,
        cond_drop_rate,
        rand_uncond,
        cross_attn_layer_scale, nm0, tau, cos_attn, swiglu,
        raw_scale_schedule,
        head_depth,
        top_p, top_k,
        customized_flash_attn, fused_mlp, fused_norm,
        block_chunks,
        checkpointing,
        pad_to_multiplier,
        use_flex_attn,
        batch_size,
        add_lvl_embeding_only_first_block,
        use_bit_label,
        rope2d_each_sa_layer,
        rope2d_normalized_by_hw,
        pn,
        train_h_div_w_list,
        video_frames,
        always_training_scales,
        apply_spatial_patchify,
        inference_mode,)

    self.car_control_convs = ControlConditionEmbedding(conditioning_embedding_channels=self.C)

    conv_in_kernel = 3
    conv_in_padding = (conv_in_kernel - 1) // 2  # keeps spatial size for an odd kernel
    self.car_var_conv = nn.Conv2d(self.C, self.C, kernel_size=conv_in_kernel, padding=conv_in_padding)
    norm_layer = partial(nn.LayerNorm, eps=norm_eps)
    self.drop_path_rate = drop_path_rate
    # Linearly increasing drop-path rates; only the first depth//2 entries are used below.
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

    if fused_norm:
        fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm
        if fused_norm_func is not None:
            # Leftover of a removed fused-norm sanity check: three (2, 1, C) random
            # tensors used to be drawn and discarded. Their only remaining effect is
            # advancing torch's global RNG; kept so seeded initialisation is unchanged.
            for _ in range(3):
                torch.randn(2, 1, self.C)
    else:
        fused_norm_func = None

    block_cls = CrossAttnBlock if self.t2i else SelfAttnBlock
    self.car_unregistered_blocks = []
    for block_idx in range(depth // 2):
        self.car_unregistered_blocks.append(block_cls(
            embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer,
            num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn,
            swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func,
            checkpointing_sa_only=self.checkpointing == 'self-attn',
            use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
        ))

    if self.num_block_chunks == 1:
        self.car_blocks = nn.ModuleList(self.car_unregistered_blocks)
    else:
        self.car_block_chunks = nn.ModuleList()
        for i in range(self.num_block_chunks // 2):
            self.car_block_chunks.append(MultipleLayers(self.car_unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk))

    # Skip fusion: concat(main, control) -> norm -> linear back to C channels.
    self.car_skip_norm = nn.ModuleList([FP32_Layernorm(2 * self.C, elementwise_affine=True, eps=1e-6) for _ in range(depth // 2)])
    self.car_skip_linear = nn.ModuleList([nn.Linear(2 * self.C, self.C) for _ in range(depth // 2)])
|
|
|
|
def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
    cfg_infer=False, x_BLC_lq=None,
    **kwargs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Teacher-forced forward pass with the convolutional CAR control branch.

    Each scale of the condition tokens is encoded by ``car_control_convs`` and
    added to a 3x3 conv (``car_var_conv``) of the main token map. The control
    blocks then run on this fused sequence. Their outputs are popped
    (last-in, first-out) into the second half of the main transformer through
    ``car_skip_norm`` + ``car_skip_linear``.

    Args:
        label_B_or_BLT: text condition, unpacked as
            ``(kv_compact, lens, cu_seqlens_k, max_seqlen_k)``.
        x_BLC_wo_prefix: input token features without the SOS prefix.
        scale_schedule: per-scale ``(t, h, w)`` patch counts.
        cfg_infer: if True, delegate to ``self.autoregressive_infer_cfg``.
        x_BLC_lq: condition features, used without ``word_embed``. Each scale
            is reshaped to ``(B, C, h, w)`` with ``C`` = the main sequence's
            channel width, so it must already have that width and one token per
            position of the SOS-prefixed sequence.

    Returns:
        ``self.get_logits`` of the first ``l_end`` (unpadded) positions.

    Raises:
        ValueError: if ``x_BLC_lq`` is not given (and ``cfg_infer`` is False).
    """
    if cfg_infer:
        return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)
    if x_BLC_lq is None:
        raise ValueError('forward() requires x_BLC_lq (the low-quality condition tokens)')

    x_BLC_wo_prefix = x_BLC_wo_prefix.float()
    x_BLC_lq = x_BLC_lq.float()
    B = x_BLC_wo_prefix.shape[0]

    with torch.amp.autocast('cuda', enabled=False):
        kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT

        # Condition dropout: replace a sample's text features with cfg_uncond.
        total = 0
        for le in lens:
            if random.random() < self.cond_drop_rate:
                kv_compact[total:total+le] = self.cfg_uncond[:le]
            total += le
        # Zero-valued term that ties cfg_uncond into the graph even when nothing was dropped.
        must_on_graph = self.cfg_uncond[0, 0] * 0
        kv_compact = self.text_norm(kv_compact).contiguous()
        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous()
        kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
        kv_compact[0, 0] += must_on_graph
        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k

        cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous()

        sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
        x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)

        # Per-scale fusion of the encoded condition with a conv of the main
        # token map. The spatial grid comes from the schedule's own (h, w), so
        # non-square scales reshape correctly (previously int(sqrt(t*h*w)) was
        # used, which only works for square, single-frame scales).
        patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]
        x_BLC_lq_list = torch.split(x_BLC_lq, patch_nums_per_level, dim=1)
        x_BLC_list = torch.split(x_BLC, patch_nums_per_level, dim=1)
        CVae = x_BLC.shape[-1]
        fused_lq_list = []
        for pn, var_x, lq_x in zip(scale_schedule, x_BLC_list, x_BLC_lq_list):
            ph, pw = pn[1], pn[2]
            var_x = self.car_var_conv(var_x.transpose(1, 2).contiguous().reshape(B, CVae, ph, pw))
            car_x = self.car_control_convs(lq_x.transpose(1, 2).contiguous().reshape(B, CVae, ph, pw))
            car_x = var_x + car_x
            fused_lq_list.append(car_x.view(B, CVae, -1).transpose(1, 2).contiguous())
        x_BLC_lq = torch.cat(fused_lq_list, dim=1)

        # Pad the sequence length to a multiple of pad_to_multiplier and build the attention mask.
        l_end = x_BLC.shape[1]
        need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end

        if self.customized_flash_attn:
            Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end]
            Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end]
            attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen)
        elif self.use_flex_attn:
            if need_to_pad:
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
                x_BLC_lq = F.pad(x_BLC_lq, (0, 0, 0, need_to_pad))
            assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'
            attn_bias_or_two_vector = None
        else:
            # Block-causal mask: a token attends to every token of its own and earlier scales.
            d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1)
            dT = d.transpose(1, 2)
            attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end)
            attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous()
            if need_to_pad:
                attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf)
                # Padded query rows see position 0 so they are not fully masked.
                attn_bias[0, 0, l_end:, 0] = 0
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
                x_BLC_lq = F.pad(x_BLC_lq, (0, 0, 0, need_to_pad))
            attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device)

    if self.use_flex_attn:
        attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)]
    else:
        attn_fn = None

    checkpointing_full_block = self.checkpointing == 'full-block' and self.training

    # Control branch: one residual per control block (or chunk).
    control_residual_f = []
    if self.num_block_chunks == 1:
        for i, b in enumerate(self.car_blocks):
            if i == 0 or not self.add_lvl_embeding_only_first_block:
                x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad)
            if checkpointing_full_block:
                x_BLC_lq = torch.utils.checkpoint.checkpoint(b, x_BLC_lq, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
            else:
                x_BLC_lq = b(x=x_BLC_lq, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
            control_residual_f.append(x_BLC_lq)
    else:
        for i, chunk in enumerate(self.car_block_chunks):
            if i == 0 or not self.add_lvl_embeding_only_first_block:
                x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad)
            x_BLC_lq = chunk(x=x_BLC_lq, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)
            control_residual_f.append(x_BLC_lq)

    # Main branch: the second half consumes the control residuals in reverse order.
    if self.num_block_chunks == 1:
        skip_start = len(self.blocks) // 2
        for i, b in enumerate(self.blocks):
            if i == 0 or not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if i >= skip_start:
                con_f = control_residual_f.pop()
                cat = self.car_skip_norm[i - skip_start](torch.cat([x_BLC, con_f], dim=-1))
                x_BLC = self.car_skip_linear[i - skip_start](cat)
            if checkpointing_full_block:
                x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
            else:
                x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
    else:
        skip_start = len(self.block_chunks) // 2
        for i, chunk in enumerate(self.block_chunks):
            if i == 0 or not self.add_lvl_embeding_only_first_block:
                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
            if i >= skip_start:
                con_f = control_residual_f.pop()
                cat = self.car_skip_norm[i - skip_start](torch.cat([x_BLC, con_f], dim=-1))
                x_BLC = self.car_skip_linear[i - skip_start](cat)
            x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)

    return self.get_logits(x_BLC[:, :l_end], cond_BD)
| |
| @torch.no_grad() |
| def forward_img_infinity(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]], |
| cfg_infer=False, |
| **kwargs, |
| ) -> Union[torch.Tensor, List[torch.Tensor]]: |
| """ |
| label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k) |
| :return: logits BLV, V is vocab_size |
| """ |
| if cfg_infer: |
| return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs) |
| |
| x_BLC_wo_prefix = x_BLC_wo_prefix.float() |
| B = x_BLC_wo_prefix.shape[0] |
|
|
| |
| with torch.amp.autocast('cuda', enabled=False): |
| kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT |
| |
| total = 0 |
| for le in lens: |
| if random.random() < self.cond_drop_rate: |
| kv_compact[total:total+le] = self.cfg_uncond[:le] |
| total += le |
| must_on_graph = self.cfg_uncond[0, 0] * 0 |
| kv_compact = self.text_norm(kv_compact).contiguous() |
| sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() |
| kv_compact = self.text_proj_for_ca(kv_compact).contiguous() |
| kv_compact[0, 0] += must_on_graph |
| ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k |
| |
| cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() |
| |
| sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1) |
| x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1) |
| |
| l_end = x_BLC.shape[1] |
| need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end |
| |
| if self.customized_flash_attn: |
| Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end] |
| Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end] |
| attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen) |
| |
| elif self.use_flex_attn: |
| if need_to_pad: |
| x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) |
| assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0' |
| attn_bias_or_two_vector = None |
| else: |
| d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1) |
| dT = d.transpose(1, 2) |
| attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end) |
| attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() |
| if need_to_pad: |
| attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf) |
| attn_bias[0, 0, l_end:, 0] = 0 |
| x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) |
| attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device) |
| |
| if self.use_flex_attn: |
| attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)] |
| else: |
| attn_fn = None |
|
|
| |
| SelfAttnBlock.forward, CrossAttnBlock.forward |
| checkpointing_full_block = self.checkpointing == 'full-block' and self.training |
| if self.num_block_chunks == 1: |
| for i, b in enumerate(self.blocks): |
| if self.add_lvl_embeding_only_first_block and i == 0: |
| x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) |
| if not self.add_lvl_embeding_only_first_block: |
| x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) |
| if checkpointing_full_block: |
| x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False) |
| else: |
| x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) |
| else: |
| for i, chunk in enumerate(self.block_chunks): |
| if self.add_lvl_embeding_only_first_block and i == 0: |
| x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) |
| if not self.add_lvl_embeding_only_first_block: |
| x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) |
| x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) |
|
|
| |
| logit_BlV = self.get_logits(x_BLC[:, :l_end], cond_BD) |
| if self.use_bit_label: |
| tmp_bs, tmp_seq_len = logits_BlV.shape[:2] |
| logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) |
| idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] |
| idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) |
| else: |
| idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| return |
| |
| # Multi-scale sampling with an extra control branch (the car_* modules). |
| # For every scale in scale_schedule the loop below: |
| #   1. builds a control feature from the matching slice of x_BLC_lq plus the |
| #      current input tokens (car_control_convs + car_var_conv), runs it through |
| #      self.car_block_chunks and collects the outputs in control_residual_f; |
| #   2. runs self.block_chunks on the current tokens, fusing popped control |
| #      features into the second half of the chunks (car_skip_norm / car_skip_linear); |
| #   3. samples token indices from the logits (optionally with classifier-free |
| #      guidance) and turns them into the next scale's input tokens. |
| # |
| # Args (as used in the body below): |
| #   vae: provides quantizer.lfq.indices_to_codes / decode (vae_type != 0) or |
| #        viz_from_ms_h_BChw (vae_type == 0). |
| #   scale_schedule: sequence of (t, h, w) per scale; x_BLC_lq is split by t*h*w |
| #        per scale, while the control-branch reshapes only use h and w. |
| #   label_B_or_BLT / negative_label_B_or_BLT: tuples unpacked as |
| #        (kv_compact, lens, cu_seqlens_k, max_seqlen_k). |
| #   cfg_list / tau_list: per-scale guidance scale and logit temperature; both must |
| #        be at least len(scale_schedule) long. Any cfg != 1 doubles the batch to |
| #        2*B (conditional half first, unconditional half second). |
| #   cfg_insertion_layer: 0 -> mix guidance on logits; 1 -> sets add_cfg_on_probs |
| #        (not read in this body); negative k -> mix hidden states after layer |
| #        len(self.unregistered_blocks) + k. |
| #   trunk_scale: stop before scale index trunk_scale. |
| #   gt_leak / gt_ls_Bl: for si < gt_leak, gt_ls_Bl[si] replaces the sampled ids. |
| #   inference_mode: toggles kv caching on unregistered_blocks / car_unregistered_blocks |
| #        instead of block_chunks / car_block_chunks. |
| #   x_BLC_lq: required despite the None default (it is expanded unconditionally); |
| #        presumably (1 or bs, sum of per-scale patch counts, C) -- confirm. |
| #   ret_img: also decode and return a uint8 image. |
| #   force_gt_Bhw, cfg_sc, ratio_Bl1, gumbel, norm_cfg, cfg_exp_k, softmax_merge_topk, |
| #   save_img_path, sampling_per_bits: not referenced in this body. |
| # NOTE(review): list defaults (cfg_list, tau_list, cfg_insertion_layer) are shared |
| #   across calls; they are only read here, never mutated. |
| # Returns: (ret, idx_Bl_list, img) where img is [] unless ret_img. ret and |
| #   idx_Bl_list are only appended to in the vae_type == 0 path. |
| @torch.no_grad() |
| def car_inference( |
| self, |
| vae=None, |
| scale_schedule=None, |
| label_B_or_BLT=None, |
| B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None, |
| g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0, |
| returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False, |
| cfg_exp_k: float=0.0, cfg_insertion_layer=[-5], |
| vae_type=0, softmax_merge_topk=-1, ret_img=False, |
| trunk_scale=1000, |
| gt_leak=0, gt_ls_Bl=None, |
| inference_mode=False, |
| save_img_path=None, |
| sampling_per_bits=1, |
| x_BLC_lq=None |
| ): |
| # Seeded sampling reuses self.rng; rng=None lets torch.multinomial use the global RNG. |
| if g_seed is None: rng = None |
| else: self.rng.manual_seed(g_seed); rng = self.rng |
| assert len(cfg_list) >= len(scale_schedule) |
| assert len(tau_list) >= len(scale_schedule) |


| |
| |
| # With spatial patchify the VAE grid is twice the transformer grid in h and w. |
| if self.apply_spatial_patchify: |
| vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] |
| else: |
| vae_scale_schedule = scale_schedule |
| |
| kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT |
| if any(np.array(cfg_list) != 1): |
| bs = 2*B |
| if not negative_label_B_or_BLT: |
| kv_compact_un = kv_compact.clone() |
| total = 0 |
| # NOTE(review): leftover debug output; consider removing or switching to logging. |
| print(f"kv_compact_un.shape {kv_compact_un.shape} self.cfg_uncond.shape {self.cfg_uncond.shape}") |
| print(lens) |
| |
| |
| |
| |
| |
| # Replace each sample's text tokens with self.cfg_uncond truncated to that sample's length. |
| for le in lens: |
| kv_compact_un[total:total+le] = (self.cfg_uncond)[:le] |
| total += le |
| kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) |
| # Append the unconditional cu_seqlens, offset by the conditional total length. |
| cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0) |
| else: |
| kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT |
| kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) |
| cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0) |
| max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un) |
| else: |
| bs = B |


| kv_compact = self.text_norm(kv_compact) |
| sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) |
| kv_compact = self.text_proj_for_ca(kv_compact) |
| ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k |
| # First-scale input: pooled text embedding (sos) plus self.pos_start, one token per batch row. |
| last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1) |
| |
| # NOTE(review): raises AttributeError if x_BLC_lq is left at its None default. |
| x_BLC_lq = x_BLC_lq.expand(bs,-1,-1) |
| |
| |


| with torch.amp.autocast('cuda', enabled=False): |
| cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous() |
| accu_BChw, cur_L, ret = None, 0, [] |
| idx_Bl_list, idx_Bld_list = [], [] |


| # Enable kv caching on the main and control blocks; it is disabled again after the |
| # scale loop. NOTE(review): no try/finally, so an exception leaves caching enabled. |
| if inference_mode: |
| for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True) |
| else: |
| assert self.num_block_chunks > 1 |
| for block_chunk_ in self.block_chunks: |
| for module in block_chunk_.module.module: |
| (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True) |
| |
| if inference_mode: |
| for b in self.car_unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True) |
| else: |
| assert self.num_block_chunks > 1 |
| for block_chunk_ in self.car_block_chunks: |
| for module in block_chunk_.module.module: |
| (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True) |
| |
| # Translate cfg_insertion_layer entries into flags / absolute layer indices. |
| abs_cfg_insertion_layers = [] |
| add_cfg_on_logits, add_cfg_on_probs = False, False |
| leng = len(self.unregistered_blocks) |
| for item in cfg_insertion_layer: |
| if item == 0: |
| add_cfg_on_logits = True |
| elif item == 1: |
| add_cfg_on_probs = True |
| elif item < 0: |
| # NOTE(review): the message prints self.num_block_chunks, but the check uses len(self.unregistered_blocks). |
| assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={self.num_block_chunks}' |
| abs_cfg_insertion_layers.append(leng+item) |
| else: |
| raise ValueError(f'cfg_insertion_layer: {item} is not valid') |
| |
| num_stages_minus_1 = len(scale_schedule)-1 |
| summed_codes = 0 |
| |
| # One lq slice per scale; requires x_BLC_lq.shape[1] == sum(patch_nums_per_level). |
| patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] |
| x_BLC_lq_list = list(torch.split(x_BLC_lq,patch_nums_per_level,dim=1)) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| for si, pn in enumerate(scale_schedule): |
| cfg = cfg_list[si] |
| if si >= trunk_scale: |
| break |
| # NOTE(review): cur_L is accumulated but not read anywhere in this body. |
| cur_L += np.array(pn).prod() |


| need_to_pad = 0 |
| attn_fn = None |
| if self.use_flex_attn: |
| |
| |
| |
| attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None) |


| |
| layer_idx = 0 |
| |
| control_residual_f = [] |
| # NOTE(review): always true here, since x_BLC_lq_list is built with torch.split above. |
| if x_BLC_lq_list is not None: |
| last_stage_channel = last_stage.shape[-1] |
| # Tokens -> (bs, C, h, w) maps for the convs; control = conv(lq slice) + conv(current tokens). |
| control_x = x_BLC_lq_list[si].transpose(1, 2).contiguous().reshape(bs, last_stage_channel, pn[1], pn[2]) |
| control_x = self.car_control_convs(control_x) |
| var_x = last_stage.transpose(1, 2).contiguous().reshape(bs, last_stage_channel, pn[1], pn[2]) |
| var_x = self.car_var_conv(var_x) |
| control_x = var_x + control_x |
| control_x = control_x.view(bs, last_stage_channel, -1).transpose(1, 2) |
| |
| |
| |
| for i, chunk in enumerate(self.car_block_chunks): |
| if self.add_lvl_embeding_only_first_block and i == 0: |
| control_x = self.add_lvl_embeding(control_x, si, scale_schedule, need_to_pad=need_to_pad) |
| if not self.add_lvl_embeding_only_first_block: |
| control_x = self.add_lvl_embeding(control_x, si, scale_schedule, need_to_pad=need_to_pad) |
| for m in chunk.module: |
| control_x = m(x=control_x, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si) |
| |
| |
| # NOTE(review): indentation is lost in this copy; presumably one entry is appended per |
| # car chunk so that the pops in the main loop below find enough items -- confirm. |
| control_residual_f.append(control_x) |
| |
| |
| for block_idx, b in enumerate(self.block_chunks): |
| |
| if self.add_lvl_embeding_only_first_block and block_idx == 0: |
| last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) |
| if not self.add_lvl_embeding_only_first_block: |
| last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) |
| |
| |
| # Second half of the main chunks: pop() takes the most recently appended control |
| # feature, which is concatenated on channels, normalised and projected back. |
| if block_idx >= len(self.block_chunks) // 2: |
| con_f = control_residual_f.pop() |
| cat = torch.cat([last_stage, con_f], dim=-1) |
| cat = self.car_skip_norm[block_idx - len(self.block_chunks) // 2](cat) |
| last_stage = self.car_skip_linear[block_idx - len(self.block_chunks) // 2](cat) |
| |
| |
| for m in b.module: |
| |
| last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si) |
| # Hidden-state CFG: mix cond/uncond halves at the requested layers, then |
| # duplicate the result so both halves continue from the guided state. |
| if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers): |
| |
| last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:] |
| last_stage = torch.cat((last_stage, last_stage), 0) |
| layer_idx += 1 |
| |
| # Logit-level CFG; otherwise only the conditional half is decoded. |
| if (cfg != 1) and add_cfg_on_logits: |
| |
| logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si]) |
| logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:] |
| else: |
| logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si]) |
| |
| # Bit labels: view logits as (B, -1, 2), sample each bit, then restore (B, L, -1). |
| # NOTE(review): idx_Bld is only defined when use_bit_label, idx_Bl only otherwise; |
| # the vae_type branches below assume the matching setting. |
| if self.use_bit_label: |
| tmp_bs, tmp_seq_len = logits_BlV.shape[:2] |
| logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) |
| idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] |
| idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) |
| else: |
| idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] |
| if vae_type != 0: |
| assert returns_vemb |
| if si < gt_leak: |
| idx_Bld = gt_ls_Bl[si] |
| else: |
| assert pn[0] == 1 |
| idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) |
| if self.apply_spatial_patchify: |
| idx_Bld = idx_Bld.permute(0,3,1,2) |
| idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) |
| idx_Bld = idx_Bld.permute(0,2,3,1) |
| idx_Bld = idx_Bld.unsqueeze(1) |


| idx_Bld_list.append(idx_Bld) |
| # Accumulate codes at the final VAE resolution; for non-final scales, downsample |
| # the running sum to the next scale's size to form its input tokens. |
| codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') |
| if si != num_stages_minus_1: |
| summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) |
| last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_down) |
| last_stage = last_stage.squeeze(-3) |
| if self.apply_spatial_patchify: |
| last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2) |
| last_stage = last_stage.reshape(*last_stage.shape[:2], -1) |
| last_stage = torch.permute(last_stage, [0,2,1]) |
| else: |
| summed_codes += codes |
| else: |
| # vae_type == 0: embed indices via quant_only_used_in_inference[0] and fuse them |
| # into the running multi-scale accumulator with one_step_fuse. |
| if si < gt_leak: |
| idx_Bl = gt_ls_Bl[si] |
| h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float() |


| |
| h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2]) |
| ret.append(h_BChw if returns_vemb != 0 else idx_Bl) |
| idx_Bl_list.append(idx_Bl) |
| if si != num_stages_minus_1: |
| accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule) |
| |
| # Project the next scale's input into model width and re-duplicate it for the CFG batch. |
| if si != num_stages_minus_1: |
| last_stage = self.word_embed(self.norm0_ve(last_stage)) |
| last_stage = last_stage.repeat(bs//B, 1, 1) |


| if inference_mode: |
| for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) |
| else: |
| assert self.num_block_chunks > 1 |
| for block_chunk_ in self.block_chunks: |
| for module in block_chunk_.module.module: |
| (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) |
| |
| if inference_mode: |
| for b in self.car_unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) |
| else: |
| assert self.num_block_chunks > 1 |
| for block_chunk_ in self.car_block_chunks: |
| for module in block_chunk_.module.module: |
| (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) |
| |


| if not ret_img: |
| return ret, idx_Bl_list, [] |
| |
| if vae_type != 0: |
| img = vae.decode(summed_codes.squeeze(-3)) |
| else: |
| img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True) |


| # Map presumably [-1, 1] output to [0, 1], then to (B, H, W, C) uint8; the final flip |
| # reverses channel order (presumably RGB -> BGR -- confirm against the consumer). |
| img = (img + 1) / 2 |
| img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) |
| return ret, idx_Bl_list, img |
|
|
def sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor:
    """Draw vocabulary ids from top-k / nucleus-filtered logits.

    The logits tensor is filtered IN PLACE: every entry removed by the top-k
    or top-p filter is overwritten with ``-inf``.

    Args:
        logits_BlV: logits of shape (B, l, V); mutated in place.
        top_k: keep only the ``top_k`` largest logits per row (0 disables).
        top_p: keep the most probable entries covering at least ``top_p`` of
            the softmax mass (0 disables). The single most likely entry is
            always kept.
        rng: optional ``torch.Generator`` forwarded to ``torch.multinomial``.
        num_samples: draws per row; a negative value means
            ``abs(num_samples)`` draws without replacement.

    Returns:
        Long tensor of shape (B, l, abs(num_samples)) with sampled ids.
    """
    batch, length, vocab = logits_BlV.shape

    if top_k > 0:
        k = min(top_k, vocab)
        kth_largest = logits_BlV.topk(k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
        logits_BlV.masked_fill_(logits_BlV < kth_largest, -torch.inf)

    if top_p > 0:
        ascending_logits, ascending_idx = logits_BlV.sort(dim=-1, descending=False)
        cumulative = ascending_logits.softmax(dim=-1).cumsum_(dim=-1)
        drop_sorted = cumulative <= (1 - top_p)
        drop_sorted[..., -1:] = False  # never drop the most likely entry
        # Scatter the mask from sorted order back to vocabulary order.
        drop = drop_sorted.scatter(ascending_idx.ndim - 1, ascending_idx, drop_sorted)
        logits_BlV.masked_fill_(drop, -torch.inf)

    with_replacement = num_samples >= 0
    draws = abs(num_samples)
    flat_probs = logits_BlV.softmax(dim=-1).view(-1, vocab)
    sampled = torch.multinomial(flat_probs, num_samples=draws, replacement=with_replacement, generator=rng)
    return sampled.view(batch, length, draws)
|
|
def sample_with_top_k_top_p_also_modifying_logits_(
    logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1
) -> torch.Tensor:
    """Out-of-place top-k / nucleus sampling from logits.

    Despite the name, the caller's tensor is NOT modified: filtering is done
    on a clone, with removed entries set to ``-inf`` before the softmax.

    Args:
        logits_BlV: logits of shape (B, l, V).
        top_k: keep only the ``top_k`` largest logits per row (0 disables).
        top_p: keep the most probable entries covering at least ``top_p`` of
            the softmax mass (0 disables); the top entry is always kept.
        rng: optional ``torch.Generator`` forwarded to ``torch.multinomial``.
        num_samples: draws per row; negative means ``abs(num_samples)`` draws
            without replacement.

    Returns:
        Long tensor of shape (B, l, abs(num_samples)) with sampled ids.
    """
    B, l, V = logits_BlV.shape
    filtered = logits_BlV.clone()

    if top_k > 0:
        threshold = filtered.topk(min(top_k, V), largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
        filtered = filtered.masked_fill(filtered < threshold, -torch.inf)

    if top_p > 0:
        sorted_vals, order = filtered.sort(dim=-1, descending=False)
        tail_mask = sorted_vals.softmax(dim=-1).cumsum(dim=-1) <= (1 - top_p)
        tail_mask[..., -1:] = False  # keep the most likely entry
        filtered = filtered.masked_fill(tail_mask.scatter(order.ndim - 1, order, tail_mask), -torch.inf)

    replace = num_samples >= 0
    n_draws = abs(num_samples)
    flat_probs = filtered.softmax(dim=-1).view(-1, V)
    return torch.multinomial(flat_probs, num_samples=n_draws, replacement=replace, generator=rng).view(B, l, n_draws)
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
|
|
def sampling_with_top_k_top_p_also_inplace_modifying_probs_(probs_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor:
    """Draw ids from top-k / nucleus-filtered *probabilities* (not logits).

    ``probs_BlV`` is filtered IN PLACE: entries removed by top-k or top-p are
    set to 0. The surviving mass is renormalised before sampling, so the input
    need not sum exactly to one.

    Args:
        probs_BlV: non-negative probabilities of shape (B, l, V); mutated in place.
        top_k: keep only the ``top_k`` largest probabilities per row (0 disables).
        top_p: keep the most probable entries covering at least ``top_p`` of
            the (renormalised) mass (0 disables); the top entry is always kept.
        rng: optional ``torch.Generator`` forwarded to ``torch.multinomial``.
        num_samples: draws per row; negative means ``abs(num_samples)`` draws
            without replacement.

    Returns:
        Long tensor of shape (B, l, abs(num_samples)) with sampled ids.
    """
    B, l, V = probs_BlV.shape
    if top_k > 0:
        top_k = min(top_k, V)
        idx_to_remove = probs_BlV < probs_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
        probs_BlV.masked_fill_(idx_to_remove, 0)
    if top_p > 0:
        sorted_probs, sorted_idx = probs_BlV.sort(dim=-1, descending=False)
        # The values are already probabilities, so renormalise them (top-k may
        # have zeroed part of the mass) instead of applying softmax. Softmax
        # would flatten the distribution and give zeroed entries weight exp(0)=1,
        # selecting the wrong nucleus.
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        sorted_idx_to_remove = sorted_probs.cumsum(dim=-1) <= (1 - top_p)
        sorted_idx_to_remove[..., -1:] = False  # always keep the most probable entry
        # Scatter the mask from sorted order back to vocabulary order.
        probs_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), 0)

    probs_BlV = probs_BlV / probs_BlV.sum(-1, keepdim=True)
    replacement = num_samples >= 0
    num_samples = abs(num_samples)
    return torch.multinomial(probs_BlV.view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
|
|
|
|
def get_params_num(d, w, mlp):
    """Return an estimated parameter count as a string such as ``'2.23B'``.

    Args:
        d: number of blocks; the per-block term is multiplied by it.
        w: model width.
        mlp: MLP ratio; the hidden size ``mlp * w`` is rounded to a multiple of 256.

    Returns:
        The estimate divided by 1e9, formatted with two decimals and a ``B`` suffix.
    """
    hidden = round(mlp * w / 256) * 256
    text_dim = 4096  # Ct5, presumably the text-feature width
    per_block = w**2 * 8 + w * hidden * 2
    # Terms are summed in the same order as the original accumulation.
    total = (
        d * per_block
        + w**2 * 6
        + 4096 * w
        + 32 * w
        + text_dim * w * 4
        + (text_dim * w + w * w)
    )
    return f'{total/1e9:.2f}B'
|
|
|
|
# Keyword arguments stripped from **kwargs by the registered model factories
# before the model constructor is called (presumably injected by timm's
# create_model machinery).
TIMM_KEYS = set(('img_size', 'pretrained', 'pretrained_cfg', 'pretrained_cfg_overlay', 'global_pool'))
|
|
@register_model
def infinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs):
    """timm-registered ``Infinity`` factory: depth 32, width 2048, 16 heads, mlp_ratio 4.

    Keys listed in TIMM_KEYS are dropped from ``kwargs``; the rest are forwarded.
    """
    model_kwargs = {name: value for name, value in kwargs.items() if name not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4,
                    drop_path_rate=drop_path_rate, **model_kwargs)
|
|
@register_model
def cinfinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs):
    """timm-registered ``CInfinity`` factory: depth 32, width 2048, 16 heads, mlp_ratio 4.

    Keys listed in TIMM_KEYS are dropped from ``kwargs``; the rest are forwarded.
    """
    model_kwargs = {name: value for name, value in kwargs.items() if name not in TIMM_KEYS}
    return CInfinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4,
                     drop_path_rate=drop_path_rate, **model_kwargs)
|
|
@register_model
def finfinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs):
    """timm-registered ``FInfinity`` factory: depth 32, width 2048, 16 heads, mlp_ratio 4.

    Keys listed in TIMM_KEYS are dropped from ``kwargs``; the rest are forwarded.
    """
    model_kwargs = {name: value for name, value in kwargs.items() if name not in TIMM_KEYS}
    return FInfinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4,
                     drop_path_rate=drop_path_rate, **model_kwargs)
|
|
@register_model
def fainfinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs):
    """timm-registered ``FAInfinity`` factory: depth 32, width 2048, 16 heads, mlp_ratio 4.

    Keys listed in TIMM_KEYS are dropped from ``kwargs``; the rest are forwarded.
    """
    model_kwargs = {name: value for name, value in kwargs.items() if name not in TIMM_KEYS}
    return FAInfinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4,
                      drop_path_rate=drop_path_rate, **model_kwargs)
|
|
@register_model
def ainfinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs):
    """timm-registered ``AInfinity`` factory: depth 32, width 2048, 16 heads, mlp_ratio 4.

    Keys listed in TIMM_KEYS are dropped from ``kwargs``; the rest are forwarded.
    """
    model_kwargs = {name: value for name, value in kwargs.items() if name not in TIMM_KEYS}
    return AInfinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4,
                     drop_path_rate=drop_path_rate, **model_kwargs)
|
|
@register_model
def binfinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs):
    """timm-registered ``BInfinity`` factory: depth 32, width 2048, 16 heads, mlp_ratio 4.

    Keys listed in TIMM_KEYS are dropped from ``kwargs``; the rest are forwarded.
    """
    model_kwargs = {name: value for name, value in kwargs.items() if name not in TIMM_KEYS}
    return BInfinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4,
                     drop_path_rate=drop_path_rate, **model_kwargs)
|
|
@register_model
def infinity_20b(depth=58, embed_dim=4608, num_heads=4608//128, drop_path_rate=0.25, **kwargs):
    """timm-registered ``Infinity`` factory: depth 58, width 4608, 36 heads, mlp_ratio 4.

    Keys listed in TIMM_KEYS are dropped from ``kwargs``; the rest are forwarded.
    """
    model_kwargs = {name: value for name, value in kwargs.items() if name not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4,
                    drop_path_rate=drop_path_rate, **model_kwargs)
|
|
| |
@register_model
def infinity_layer12(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, **kwargs):
    """timm-registered ``Infinity`` factory: 12 blocks, width 768, 8 heads, mlp_ratio 4.

    Keys listed in TIMM_KEYS are dropped from ``kwargs``; the rest are forwarded.
    """
    model_kwargs = {name: value for name, value in kwargs.items() if name not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4,
                    drop_path_rate=drop_path_rate, **model_kwargs)
@register_model
def infinity_layer16(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, **kwargs):
    """timm-registered ``Infinity`` factory: 16 blocks, width 1152, 12 heads, mlp_ratio 4.

    Keys listed in TIMM_KEYS are dropped from ``kwargs``; the rest are forwarded.
    """
    model_kwargs = {name: value for name, value in kwargs.items() if name not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4,
                    drop_path_rate=drop_path_rate, **model_kwargs)
@register_model
def infinity_layer24(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, **kwargs):
    """timm-registered ``Infinity`` factory: 24 blocks, width 1536, 16 heads, mlp_ratio 4.

    Keys listed in TIMM_KEYS are dropped from ``kwargs``; the rest are forwarded.
    """
    model_kwargs = {name: value for name, value in kwargs.items() if name not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4,
                    drop_path_rate=drop_path_rate, **model_kwargs)
@register_model
def infinity_layer32(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, **kwargs):
    """timm-registered ``Infinity`` factory: 32 blocks, width 2080, 20 heads, mlp_ratio 4.

    Keys listed in TIMM_KEYS are dropped from ``kwargs``; the rest are forwarded.
    """
    model_kwargs = {name: value for name, value in kwargs.items() if name not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4,
                    drop_path_rate=drop_path_rate, **model_kwargs)
@register_model
def infinity_layer40(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, **kwargs):
    """timm-registered ``Infinity`` factory: 40 blocks, width 2688, 24 heads, mlp_ratio 4.

    Keys listed in TIMM_KEYS are dropped from ``kwargs``; the rest are forwarded.
    """
    model_kwargs = {name: value for name, value in kwargs.items() if name not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4,
                    drop_path_rate=drop_path_rate, **model_kwargs)
@register_model
def infinity_layer48(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, **kwargs):
    """timm-registered ``Infinity`` factory: 48 blocks, width 3360, 28 heads, mlp_ratio 4.

    Keys listed in TIMM_KEYS are dropped from ``kwargs``; the rest are forwarded.
    """
    model_kwargs = {name: value for name, value in kwargs.items() if name not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4,
                    drop_path_rate=drop_path_rate, **model_kwargs)
|
|