diff --git "a/infinity/models/infinity.py" "b/infinity/models/infinity.py" new file mode 100644--- /dev/null +++ "b/infinity/models/infinity.py" @@ -0,0 +1,6315 @@ +""" +Definition of Infinity transformer model. +""" + +import math +import random +import time +from contextlib import nullcontext +from functools import partial +from typing import List, Optional, Tuple, Union, Dict, Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models import register_model +from torch.utils.checkpoint import checkpoint +from PIL import Image +import numpy as np +from torch.nn.attention.flex_attention import flex_attention + +import infinity.utils.dist as dist + +from infinity.utils.dist import for_visualize +from infinity.models.basic import flash_attn_func, flash_fused_op_installed, AdaLNBeforeHead, CrossAttnBlock, SelfAttnBlock, CrossAttention, FastRMSNorm, precompute_rope2d_freqs_grid +from infinity.utils import misc +from infinity.models.flex_attn import FlexAttn +from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates +try: + from infinity.models.fused_op import fused_ada_layer_norm, fused_ada_rms_norm +except: + fused_ada_layer_norm, fused_ada_rms_norm = None, None + +import pdb + +class MultiInpIdentity(nn.Module): + def forward(self, x, *args, **kwargs): + return x + + +class TextAttentivePool(nn.Module): + def __init__(self, Ct5: int, D: int): + super().__init__() + self.Ct5, self.D = Ct5, D + if D > 4096: + self.head_dim = 64 + else: + self.head_dim = 128 + + self.num_heads = Ct5 // self.head_dim + self.ca = CrossAttention(for_attn_pool=True, embed_dim=self.D, kv_dim=Ct5, num_heads=self.num_heads) + def forward(self, ca_kv): + return self.ca(None, ca_kv).squeeze(1) + +class SharedAdaLin(nn.Linear): + def forward(self, cond_BD): + C = self.weight.shape[0] // 6 + return super().forward(cond_BD).reshape(-1, 1, 6, C) # B16C + + +class MultipleLayers(nn.Module): + def __init__(self, ls, num_blocks_in_a_chunk, index): + super().__init__() + self.module = nn.ModuleList() + for i in range(index, index+num_blocks_in_a_chunk): + self.module.append(ls[i]) + + def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, checkpointing_full_block=False, rope2d_freqs_grid=None): + h = x + for m in self.module: + if checkpointing_full_block: + h = torch.utils.checkpoint.checkpoint(m, h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False) + else: + h = m(h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid) + return h + + def forward_fsdp(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, checkpointing_full_block=False, rope2d_freqs_grid=None): + h = x + for m in self.module.module: + if checkpointing_full_block: + h = torch.utils.checkpoint.checkpoint(m, h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False) + else: + h = m(h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid) + return h + +class STGumbelArgmax(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, tau=1.0): + # 前向传播:生成近似 one-hot 向量 [B, L, C] + y_hard = F.gumbel_softmax(logits, tau=tau, hard=True) + # 保存中间结果用于反向传播 + ctx.save_for_backward(logits) + ctx.tau = tau + return y_hard + + @staticmethod + def backward(ctx, grad_output): + # 反向传播:用 Gumbel-Softmax 的软概率梯度替代硬 one-hot 的梯度 + logits, = ctx.saved_tensors + tau = ctx.tau + 
# 计算软概率的梯度 + soft_grad = F.gumbel_softmax(logits, tau=tau, hard=False) + grad_input = grad_output * soft_grad + return grad_input, None + +def GumbelArgmax(logits,tau): + + U = torch.rand(logits.shape).to(logits.device) + eps = 1e-20 + + gumbel_noise = -torch.log(-torch.log(U + eps) + eps) + perturbed_logits = (logits + gumbel_noise) / tau + + y_soft = F.softmax(perturbed_logits, dim=-1) + index = y_soft.max(dim=-1, keepdim=True)[1] + y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0) + + return y_hard - y_soft.detach() + y_soft + +class Infinity(nn.Module): + def __init__( + self, vae_local, + text_channels=0, text_maxlen=0, # text-cond generation + selecting_idx=None, # class-cond generation + embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4., # model's architecture + drop_rate=0., drop_path_rate=0., # drop out and drop path + norm_eps=1e-6, rms_norm=False, # norm layer + shared_aln=False, head_aln=True, # adaptive norm + cond_drop_rate=0.1, # for classifier-free guidance + rand_uncond=False, + cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False, + raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), + head_depth=1, + top_p=0.0, top_k=0.0, + customized_flash_attn=False, fused_mlp=False, fused_norm=False, + block_chunks=1, + checkpointing=None, + pad_to_multiplier=0, + use_flex_attn=False, + batch_size=2, + add_lvl_embeding_only_first_block=1, + use_bit_label=1, + rope2d_each_sa_layer=0, + rope2d_normalized_by_hw=0, + pn=None, + train_h_div_w_list=None, + video_frames=1, + always_training_scales=20, + apply_spatial_patchify = 0, + inference_mode=False, + ): + # set hyperparameters + self.C = embed_dim + self.inference_mode = inference_mode + self.apply_spatial_patchify = apply_spatial_patchify + if self.apply_spatial_patchify: + self.d_vae = vae_local.embed_dim * 4 + else: + self.d_vae = vae_local.embed_dim + self.use_bit_label = use_bit_label + self.codebook_dim = self.d_vae + self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size + self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None + self.Ct5 = text_channels + self.depth = depth + self.num_heads = num_heads + self.batch_size = batch_size + self.mlp_ratio = mlp_ratio + self.cond_drop_rate = cond_drop_rate + self.norm_eps = norm_eps + self.prog_si = -1 + self.pn = pn + self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates + self.video_frames = video_frames + self.always_training_scales = always_training_scales + + assert add_lvl_embeding_only_first_block in [0,1] + self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block + assert rope2d_each_sa_layer in [0,1] + self.rope2d_each_sa_layer = rope2d_each_sa_layer + self.rope2d_normalized_by_hw = rope2d_normalized_by_hw + print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, \ + self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}') + head_up_method = '' + word_patch_size = 1 if head_up_method in {'', 'no'} else 2 + if word_patch_size > 1: + assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}' + + self.checkpointing = checkpointing + self.pad_to_multiplier = max(1, pad_to_multiplier) + + customized_kernel_installed = any('Infinity' in arg_name for arg_name in 
flash_attn_func.__code__.co_varnames) + self.customized_flash_attn = customized_flash_attn and customized_kernel_installed + if customized_flash_attn and not customized_kernel_installed: + import inspect, warnings + file_path = inspect.getsourcefile(flash_attn_func) + line_number = inspect.getsourcelines(flash_attn_func)[1] + info = ( + f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<\n' + f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n' + f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n' + ) + warnings.warn(info, ImportWarning) + print(info, flush=True) + + self.raw_scale_schedule = raw_scale_schedule # 'raw' means before any patchifying + self.first_l = 1 + # solve top-p top-k sampling hyperparameters + self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k)) + if self.top_p < 1e-5: self.top_p = 0 + if self.top_k >= self.V or self.top_k <= 0: self.top_k = 0 + + t = torch.zeros(dist.get_world_size(), device=dist.get_device()) + t[dist.get_rank()] = float(flash_fused_op_installed) + dist.barrier() + dist.allreduce(t) + assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}' + + super().__init__() + self.rng = torch.Generator(device=dist.get_device()) + self.maybe_record_function = nullcontext + self.text_maxlen = text_maxlen + self.t2i = text_channels != 0 + + # [inp & position embedding] + init_std = math.sqrt(1 / self.C / 3) + self.norm0_cond = nn.Identity() + if self.t2i: + self.selecting_idx = None + self.num_classes = 0 + self.D = self.C + + cfg_uncond = torch.empty(self.text_maxlen, self.Ct5) + rng = torch.Generator(device='cpu') + rng.manual_seed(0) + torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng) + cfg_uncond /= self.Ct5 ** 0.5 + if rand_uncond: + self.register_buffer('cfg_uncond', cfg_uncond) + else: + self.cfg_uncond = nn.Parameter(cfg_uncond) + + self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps) + self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D) + self.text_proj_for_ca = nn.Sequential( + nn.Linear(self.Ct5, self.D), + nn.GELU(approximate='tanh'), + nn.Linear(self.D, self.D), + ) + else: # class-label cond + if selecting_idx is None: + num_classes = 1000 + print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======') + selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device()) + self.selecting_idx = selecting_idx + self.num_classes = selecting_idx.shape[-1] + self.D = self.C + self.class_emb = nn.Embedding(self.num_classes + 1, self.C) + nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std) + + self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C)) + nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std) + if self.rope2d_each_sa_layer: + rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw) + self.rope2d_freqs_grid = rope2d_freqs_grid + else: + raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented') + self.lvl_embed = nn.Embedding(15, self.C) + nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std) + + # [input layers] input norm && input 
embedding + norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps) + self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity() + self.word_embed = nn.Linear(self.d_vae, self.C) + + # [shared adaptive layernorm mapping network] + self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity() + + # fused norm + if fused_norm: + fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm + if fused_norm_func is not None: # pre-compile + B = 2 + x = torch.randn(B, 1, self.C).requires_grad_(True) + scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + # fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale, shift=shift).mean().backward() + del B, x, scale, shift + else: + fused_norm_func = None + + # [backbone and head] + self.use_flex_attn = use_flex_attn + self.attn_fn_compile_dict = {} + self.batch_size = batch_size + if self.use_flex_attn: + self.attn_fn_compile_dict = self.compile_flex_attn() + + self.drop_path_rate = drop_path_rate + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # dpr means drop path rate (linearly increasing) + self.unregistered_blocks = [] + for block_idx in range(depth): + block = (CrossAttnBlock if self.t2i else SelfAttnBlock)( + embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer, + num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn, + swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func, + checkpointing_sa_only=self.checkpointing == 'self-attn', + use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw, + ) + self.unregistered_blocks.append(block) + + # [head] + V = self.V + if head_aln: + self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func) + self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V)) + else: + self.head_nm = MultiInpIdentity() + self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V)) + + self.num_block_chunks = block_chunks or 1 + self.num_blocks_in_a_chunk = depth // block_chunks + print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}") + assert self.num_blocks_in_a_chunk * block_chunks == depth + if self.num_block_chunks == 1: + self.blocks = nn.ModuleList(self.unregistered_blocks) + else: + self.block_chunks = nn.ModuleList() + for i in range(self.num_block_chunks): + self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk)) + print( + f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n' + f' [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, 
mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n' + f' [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})', + end='\n\n', flush=True + ) + + + def compile_flex_attn(self): + attn_fn_compile_dict = {} + for h_div_w in self.train_h_div_w_list: + h_div_w_template = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))] + full_scale_schedule = dynamic_resolution_h_w[h_div_w_template][self.pn]['scales'] + if self.inference_mode: + apply_flex_attn_scales = list(range(1, 1+len(full_scale_schedule))) + mask_type = "infinity_infer_mask_with_kv_cache" + auto_padding = True + else: + mask_type = 'var' + auto_padding = False + apply_flex_attn_scales = [min(self.always_training_scales, len(full_scale_schedule))] + for scales_num in apply_flex_attn_scales: + print(f'====== apply flex attn hdivw: {h_div_w} scales: {scales_num} ======') + scale_schedule = full_scale_schedule[:scales_num] + scale_schedule = [ (min(t, self.video_frames//4+1), h, w) for (t,h, w) in scale_schedule] + patchs_nums_tuple = tuple(scale_schedule) + SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple) + aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L + attn_fn = FlexAttn(block_scales = patchs_nums_tuple, + mask_type = mask_type, + B = self.batch_size, + H = self.num_heads, + L = aligned_L, + auto_padding=auto_padding) + attn_fn_compile_dict[patchs_nums_tuple] = attn_fn + + if self.video_frames > 1: # append image attn_fn when self.video_frames > 1 (namely videos) + scale_schedule = [ (1, h, w) for (t,h, w) in scale_schedule] + patchs_nums_tuple = tuple(scale_schedule) + SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple) + aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L + attn_fn = FlexAttn(block_scales = patchs_nums_tuple, + mask_type = mask_type, + B = self.batch_size, + H = self.num_heads, + L = aligned_L) + attn_fn_compile_dict[patchs_nums_tuple] = attn_fn + return attn_fn_compile_dict + + def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]): + """ + :param h: hidden_state, shaped (B or batch_size, L or seq_len, C or hidden_dim) + :param cond_BD: shaped (B or batch_size, D or cond_dim) + :param tau: temperature + :return: logits, shaped (B or batch_size, V or vocabulary_size) + """ + with torch.amp.autocast('cuda', enabled=False): + return self.head(self.head_nm(h.float(), cond_BD.float())) + + def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0): + bs, seq_len, c = feature.shape + patch_t, patch_h, patch_w = scale_schedule[scale_ind] + t_mul_h_mul_w = patch_t * patch_h * patch_w + assert t_mul_h_mul_w + need_to_pad == seq_len + feature[:, :t_mul_h_mul_w] += self.lvl_embed(scale_ind*torch.ones((bs, t_mul_h_mul_w),dtype=torch.int).to(feature.device)) + return feature + + def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0): + ptr = 0 + x_BLC_list = [] + for scale_ind, patch_t_h_w in enumerate(scale_schedule): + scale_seq_len = np.array(patch_t_h_w).prod() + x_BLC_this_scale = x_BLC[:,ptr:ptr+scale_seq_len] # shape: [bs, patch_h*patch_w, c] + ptr += scale_seq_len + x_BLC_this_scale = self.add_lvl_embeding(x_BLC_this_scale, scale_ind, scale_schedule) + x_BLC_list.append(x_BLC_this_scale) + assert x_BLC.shape[1] == (ptr + need_to_pad), f'{x_BLC.shape[1]} != {ptr} + 
{need_to_pad}' + x_BLC_list.append(x_BLC[:,ptr:]) + x_BLC = torch.cat(x_BLC_list, dim=1) + return x_BLC + + def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]], + cfg_infer=False, + **kwargs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV + """ + label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k) + :return: logits BLV, V is vocab_size + """ + if cfg_infer: + return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs) + + x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32 + B = x_BLC_wo_prefix.shape[0] + + # [1. get input sequence x_BLC] + with torch.amp.autocast('cuda', enabled=False): + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + # drop cond + total = 0 + for le in lens: + if random.random() < self.cond_drop_rate: + kv_compact[total:total+le] = self.cfg_uncond[:le] + total += le + must_on_graph = self.cfg_uncond[0, 0] * 0 + kv_compact = self.text_norm(kv_compact).contiguous() + sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32 + kv_compact = self.text_proj_for_ca(kv_compact).contiguous() + kv_compact[0, 0] += must_on_graph + ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + + cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32 + + sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1) + x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1) + # [1.1. pad the seqlen dim] + l_end = x_BLC.shape[1] + need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0 + + if self.customized_flash_attn: + Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end] + Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end] + attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen) + # todo: solve need_to_pad here + elif self.use_flex_attn: + if need_to_pad: + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0' + attn_bias_or_two_vector = None + else: + d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1) + dT = d.transpose(1, 2) # dT: 11L + attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end) + attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL + if need_to_pad: + attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf) + attn_bias[0, 0, l_end:, 0] = 0 + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device) + + if self.use_flex_attn: + attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)] + else: + attn_fn = None + + # [2. 
block loop] + SelfAttnBlock.forward, CrossAttnBlock.forward + checkpointing_full_block = self.checkpointing == 'full-block' and self.training + if self.num_block_chunks == 1: + for i, b in enumerate(self.blocks): + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if checkpointing_full_block: + x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False) + else: + x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) + else: + for i, chunk in enumerate(self.block_chunks): # this path + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) + + # [3. unpad the seqlen dim, and then get logits] + return self.get_logits(x_BLC[:, :l_end], cond_BD) # return logits BLV, V is vocab_size + + @torch.no_grad() + def autoregressive_infer_cfg( + self, + vae=None, + scale_schedule=None, + label_B_or_BLT=None, + B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None, + g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0, + returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False, + cfg_exp_k: float=0.0, cfg_insertion_layer=[-5], + vae_type=0, softmax_merge_topk=-1, ret_img=False, + trunk_scale=1000, + gt_leak=0, gt_ls_Bl=None, + inference_mode=False, + save_img_path=None, + sampling_per_bits=1, + ): # returns List[idx_Bl] + if g_seed is None: rng = None + else: self.rng.manual_seed(g_seed); rng = self.rng + assert len(cfg_list) >= len(scale_schedule) + assert len(tau_list) >= len(scale_schedule) + + # scale_schedule is used by infinity, vae_scale_schedule is used by vae if there exists a spatial patchify, + # we need to convert scale_schedule to vae_scale_schedule by multiply 2 to h and w + if self.apply_spatial_patchify: + vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] + else: + vae_scale_schedule = scale_schedule + + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + if any(np.array(cfg_list) != 1): + bs = 2*B + if not negative_label_B_or_BLT: + kv_compact_un = kv_compact.clone() + total = 0 + for le in lens: + kv_compact_un[total:total+le] = (self.cfg_uncond)[:le] + total += le + kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) + cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0) + else: + kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT + kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) + cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0) + max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un) + else: + bs = B + + kv_compact = self.text_norm(kv_compact) + sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) 
# sos shape: [2, 4096] + kv_compact = self.text_proj_for_ca(kv_compact) # kv_compact shape: [304, 4096] + ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1) + + with torch.amp.autocast('cuda', enabled=False): + cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous() + accu_BChw, cur_L, ret = None, 0, [] # current length, list of reconstructed images + idx_Bl_list, idx_Bld_list = [], [] + + if inference_mode: + for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True) + else: + assert self.num_block_chunks > 1 + for block_chunk_ in self.block_chunks: + for module in block_chunk_.module.module: + (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True) + + abs_cfg_insertion_layers = [] + add_cfg_on_logits, add_cfg_on_probs = False, False + leng = len(self.unregistered_blocks) + for item in cfg_insertion_layer: + if item == 0: # add cfg on logits + add_cfg_on_logits = True + elif item == 1: # add cfg on probs + add_cfg_on_probs = True # todo in the future, we may want to add cfg on logits and probs + elif item < 0: # determine to add cfg at item-th layer's output + assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={self.num_block_chunks}' + abs_cfg_insertion_layers.append(leng+item) + else: + raise ValueError(f'cfg_insertion_layer: {item} is not valid') + + num_stages_minus_1 = len(scale_schedule)-1 + summed_codes = 0 + for si, pn in enumerate(scale_schedule): # si: i-th segment + cfg = cfg_list[si] + if si >= trunk_scale: + break + cur_L += np.array(pn).prod() + + need_to_pad = 0 + attn_fn = None + if self.use_flex_attn: + # need_to_pad = (self.pad_to_multiplier - cur_L % self.pad_to_multiplier) % self.pad_to_multiplier + # if need_to_pad: + # last_stage = F.pad(last_stage, (0, 0, 0, need_to_pad)) + attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None) + + # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item' + layer_idx = 0 + for block_idx, b in enumerate(self.block_chunks): + # last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape [5], ca_kv[2]: int + if self.add_lvl_embeding_only_first_block and block_idx == 0: + last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) + if not self.add_lvl_embeding_only_first_block: + last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) + for m in b.module: + last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si) + if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers): + # print(f'add cfg={cfg} on {layer_idx}-th layer output') + last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:] + last_stage = torch.cat((last_stage, last_stage), 0) + layer_idx += 1 + + if (cfg != 1) and add_cfg_on_logits: + # print(f'add cfg on add_cfg_on_logits') + logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si]) + logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:] + else: + logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si]) + 
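+            # Editorial note on the guidance above: with the batch packed as
+            # [conditional; unconditional] halves, the guided logits are
+            #     cfg * logits_cond + (1 - cfg) * logits_uncond
+            #   = logits_uncond + cfg * (logits_cond - logits_uncond),
+            # so cfg == 1 reduces to the plain conditional logits, which is why the
+            # cfg != 1 branches can skip the unconditional half. Both branches also
+            # fold in the temperature via .mul(1/tau_list[si]) before sampling.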
+ if self.use_bit_label: + tmp_bs, tmp_seq_len = logits_BlV.shape[:2] + logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) + idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) + else: + idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + if vae_type != 0: + assert returns_vemb + if si < gt_leak: + idx_Bld = gt_ls_Bl[si] + else: + assert pn[0] == 1 + idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d] + if self.apply_spatial_patchify: # unpatchify operation + idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w] + idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w] + idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d] + idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d] + + idx_Bld_list.append(idx_Bld) + codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + if si != num_stages_minus_1: + summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) + last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_down) # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + last_stage = last_stage.squeeze(-3) # [B, d, h, w] or [B, d, 2h, 2w] + if self.apply_spatial_patchify: # patchify operation + last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2) # [B, 4d, h, w] + last_stage = last_stage.reshape(*last_stage.shape[:2], -1) # [B, d, h*w] or [B, 4d, h*w] + last_stage = torch.permute(last_stage, [0,2,1]) # [B, h*w, d] or [B, h*w, 4d] + else: + summed_codes += codes + else: + if si < gt_leak: + idx_Bl = gt_ls_Bl[si] + h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float() # BlC + + # h_BChw = h_BChw.float().transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1]) + h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2]) + ret.append(h_BChw if returns_vemb != 0 else idx_Bl) + idx_Bl_list.append(idx_Bl) + if si != num_stages_minus_1: + accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule) + + if si != num_stages_minus_1: + last_stage = self.word_embed(self.norm0_ve(last_stage)) + last_stage = last_stage.repeat(bs//B, 1, 1) + + if inference_mode: + for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) + else: + assert self.num_block_chunks > 1 + for block_chunk_ in self.block_chunks: + for module in block_chunk_.module.module: + (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) + + if not ret_img: + return ret, idx_Bl_list, [] + + if vae_type != 0: + img = vae.decode(summed_codes.squeeze(-3)) + else: + img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True) + + img = (img + 1) / 2 + img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) + return ret, idx_Bl_list, img + + @torch.no_grad() + def autoregressive_infer_cfg_w_lq_token( + self, + vae=None, + scale_schedule=None, + label_B_or_BLT=None, + B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None, + g_seed=None, cfg_list=[], 
tau_list=[], cfg_sc=3, top_k=0, top_p=0.0, + returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False, + cfg_exp_k: float=0.0, cfg_insertion_layer=[-5], + vae_type=0, softmax_merge_topk=-1, ret_img=False, + trunk_scale=1000, + gt_leak=0, gt_ls_Bl=None, + inference_mode=False, + save_img_path=None, + sampling_per_bits=1, + x_BLC_wo_prefix_lq=None, + gt_BL_list=[], + ): # returns List[idx_Bl] + if g_seed is None: rng = None + else: self.rng.manual_seed(g_seed); rng = self.rng + assert len(cfg_list) >= len(scale_schedule) + assert len(tau_list) >= len(scale_schedule) + + + # scale_schedule is used by infinity, vae_scale_schedule is used by vae if there exists a spatial patchify, + # we need to convert scale_schedule to vae_scale_schedule by multiply 2 to h and w + if self.apply_spatial_patchify: + vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] + else: + vae_scale_schedule = scale_schedule + + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + if any(np.array(cfg_list) != 1): + bs = 2*B + if not negative_label_B_or_BLT: + kv_compact_un = kv_compact.clone() + total = 0 + for le in lens: + kv_compact_un[total:total+le] = (self.cfg_uncond)[:le] + total += le + kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) + cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0) + else: + kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT + kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) + cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0) + max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un) + else: + bs = B + + kv_compact = self.text_norm(kv_compact) + sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) # sos shape: [2, 4096] + kv_compact = self.text_proj_for_ca(kv_compact) # kv_compact shape: [304, 4096] + ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1) + + x_BLC_wo_prefix_lq = x_BLC_wo_prefix_lq.float() + x_BLC_wo_prefix_lq = x_BLC_wo_prefix_lq.expand(bs,-1,-1) + x_BLC_lq = torch.cat((last_stage, self.word_embed(self.norm0_ve(x_BLC_wo_prefix_lq))), dim=1) + + patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # note important pn[0]==1? 
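+        # Illustrative example (values not from the original source): with
+        # scale_schedule = [(1, 1, 1), (1, 2, 2), (1, 4, 4)],
+        # patch_nums_per_level == [1, 4, 16], so the split below yields per-scale
+        # chunks of shape [bs, 1, C], [bs, 4, C] and [bs, 16, C]; the first chunk
+        # holds the sos token, and x_BLC_lq_list[si] replaces last_stage for the
+        # early scales (the si < 10 branch further down).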
+ x_BLC_lq_list = list(torch.split(x_BLC_lq,patch_nums_per_level,dim=1)) + + with torch.amp.autocast('cuda', enabled=False): + cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous() + accu_BChw, cur_L, ret = None, 0, [] # current length, list of reconstructed images + idx_Bl_list, idx_Bld_list = [], [] + + if inference_mode: + for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True) + else: + assert self.num_block_chunks > 1 + for block_chunk_ in self.block_chunks: + for module in block_chunk_.module.module: + (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True) + + abs_cfg_insertion_layers = [] + add_cfg_on_logits, add_cfg_on_probs = False, False + leng = len(self.unregistered_blocks) + for item in cfg_insertion_layer: + if item == 0: # add cfg on logits + add_cfg_on_logits = True + elif item == 1: # add cfg on probs + add_cfg_on_probs = True # todo in the future, we may want to add cfg on logits and probs + elif item < 0: # determine to add cfg at item-th layer's output + assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={self.num_block_chunks}' + abs_cfg_insertion_layers.append(leng+item) + else: + raise ValueError(f'cfg_insertion_layer: {item} is not valid') + + num_stages_minus_1 = len(scale_schedule)-1 + summed_codes = 0 + logits_BlV_list = [] + for si, pn in enumerate(scale_schedule): # si: i-th segment + cfg = cfg_list[si] + if si >= trunk_scale: + break + cur_L += np.array(pn).prod() + + ###### + if si < 10: + last_stage = x_BLC_lq_list[si] + ###### + + need_to_pad = 0 + attn_fn = None + if self.use_flex_attn: + # need_to_pad = (self.pad_to_multiplier - cur_L % self.pad_to_multiplier) % self.pad_to_multiplier + # if need_to_pad: + # last_stage = F.pad(last_stage, (0, 0, 0, need_to_pad)) + attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None) + + # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item' + layer_idx = 0 + for block_idx, b in enumerate(self.block_chunks): + # last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape [5], ca_kv[2]: int + if self.add_lvl_embeding_only_first_block and block_idx == 0: + last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) + if not self.add_lvl_embeding_only_first_block: + last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) + for m in b.module: + last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si) + if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers): + # print(f'add cfg={cfg} on {layer_idx}-th layer output') + last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:] + last_stage = torch.cat((last_stage, last_stage), 0) + layer_idx += 1 + + if (cfg != 1) and add_cfg_on_logits: + # print(f'add cfg on add_cfg_on_logits') + logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si]) + logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:] + else: + logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si]) + + logits_BlV_list.append(logits_BlV) + + if self.use_bit_label: + tmp_bs, 
tmp_seq_len = logits_BlV.shape[:2] + logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) + + idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + + # ##### + # idx_Bld = GumbelArgmax(logits_BlV, 0.5) + # tmp_tensor = torch.zeros_like(idx_Bld).to(idx_Bld.device) + # tmp_tensor[:,:,1:]=1 + # idx_Bld = idx_Bld * tmp_tensor + # idx_Bld = idx_Bld.sum(dim=-1) + # ##### + + idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) + + ### + if si < len(gt_BL_list): + idx_Bld = gt_BL_list[si] + ### + else: + idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + if vae_type != 0: + assert returns_vemb + if si < gt_leak: + idx_Bld = gt_ls_Bl[si] + else: + assert pn[0] == 1 + idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d] + if self.apply_spatial_patchify: # unpatchify operation + idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w] + idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w] + idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d] + idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d] + + idx_Bld_list.append(idx_Bld) + codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + if si != num_stages_minus_1: + summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) + last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_down) # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + last_stage = last_stage.squeeze(-3) # [B, d, h, w] or [B, d, 2h, 2w] + if self.apply_spatial_patchify: # patchify operation + last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2) # [B, 4d, h, w] + last_stage = last_stage.reshape(*last_stage.shape[:2], -1) # [B, d, h*w] or [B, 4d, h*w] + last_stage = torch.permute(last_stage, [0,2,1]) # [B, h*w, d] or [B, h*w, 4d] + else: + summed_codes += codes + else: + if si < gt_leak: + idx_Bl = gt_ls_Bl[si] + h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float() # BlC + + # h_BChw = h_BChw.float().transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1]) + h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2]) + ret.append(h_BChw if returns_vemb != 0 else idx_Bl) + idx_Bl_list.append(idx_Bl) + if si != num_stages_minus_1: + accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule) + + if si != num_stages_minus_1: + last_stage = self.word_embed(self.norm0_ve(last_stage)) + last_stage = last_stage.repeat(bs//B, 1, 1) + + if inference_mode: + for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) + else: + assert self.num_block_chunks > 1 + for block_chunk_ in self.block_chunks: + for module in block_chunk_.module.module: + (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) + + if not ret_img: + return ret, idx_Bl_list, [] + + if vae_type != 0: + img = vae.decode(summed_codes.squeeze(-3)) + else: + img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True) + + img = (img + 1) / 2 + img = img.permute(0, 2, 3, 
1).mul_(255).to(torch.uint8).flip(dims=(3,))
+
+        logits_BlV_all = torch.cat(logits_BlV_list, dim=1)
+        return ret, idx_Bl_list, img, logits_BlV_all
+
+    def logits_to_img(self, logits_BlV_all, vae, scale_schedule, top_k=900, top_p=0.97, g_seed=1):
+        # logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si])
+        patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]  # note: pn[0] == 1 is assumed
+        logits_BlV_list = list(torch.split(logits_BlV_all, patch_nums_per_level, dim=1))
+
+        B = logits_BlV_all.shape[0]
+
+        if g_seed is None: rng = None
+        else: self.rng.manual_seed(g_seed); rng = self.rng
+
+        if self.apply_spatial_patchify:
+            vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
+        else:
+            vae_scale_schedule = scale_schedule
+
+        summed_codes = 0
+        num_stages_minus_1 = len(scale_schedule)-1
+
+        for si, logits_BlV in enumerate(logits_BlV_list):
+            pn = scale_schedule[si]
+            if self.use_bit_label:
+                tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
+                logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
+
+                # idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
+
+                # #####
+                # idx_Bld = STGumbelArgmax.apply(logits_BlV, 0.5)
+                # tmp_tensor = torch.zeros_like(idx_Bld).to(idx_Bld.device)
+                # tmp_tensor[:,:,1:]=1
+                # idx_Bld = idx_Bld * tmp_tensor
+                # idx_Bld = idx_Bld.sum(dim=-1)
+                # #####
+
+                #####
+                idx_Bld = GumbelArgmax(logits_BlV, 0.5)
+                tmp_tensor = torch.zeros_like(idx_Bld).to(idx_Bld.device)
+                tmp_tensor[:,:,1:] = 1
+                idx_Bld = idx_Bld * tmp_tensor
+                idx_Bld = idx_Bld.sum(dim=-1)
+                #####
+
+                idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)
+            else:
+                idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
+
+            ##### vae_type != 0
+            ### si >= gt_leak
+            assert pn[0] == 1
+            idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1)  # shape: [B, h, w, d] or [B, h, w, 4d]
+            if self.apply_spatial_patchify:  # unpatchify operation
+                idx_Bld = idx_Bld.permute(0,3,1,2)  # [B, 4d, h, w]
+                idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2)  # [B, d, 2h, 2w]
+                idx_Bld = idx_Bld.permute(0,2,3,1)  # [B, 2h, 2w, d]
+            idx_Bld = idx_Bld.unsqueeze(1)  # [B, 1, h, w, d] or [B, 1, 2h, 2w, d]
+
+            codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label')  # [B, d, 1, h, w] or [B, d, 1, 2h, 2w]
+            if si != num_stages_minus_1:
+                summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
+            else:
+                summed_codes += codes
+
+        # if inference_mode:
+        #     for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False)
+        # else:
+        #     assert self.num_block_chunks > 1
+        #     for block_chunk_ in self.block_chunks:
+        #         for module in block_chunk_.module.module:
+        #             (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False)
+
+        # vae_type != 0:
+        img = vae.decode(summed_codes.squeeze(-3))
+        # img = (img + 1) / 2
+        # img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
+        return img
+
+    def forward_teacher(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
+        cfg_infer=False,
+        **kwargs,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:  # returns logits_BLV
+        """
+        label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k)
+        :return: logits BLV, V is vocab_size
+        """
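+        # Editorial note: the body below is a verbatim copy of forward(); it is kept
+        # as a separate entry point, presumably so a frozen teacher can be invoked
+        # independently of the student's forward() during distillation. If the two
+        # paths ever need to diverge, only this copy should change.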
+ if cfg_infer: + return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs) + + x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32 + B = x_BLC_wo_prefix.shape[0] + + # [1. get input sequence x_BLC] + with torch.amp.autocast('cuda', enabled=False): + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + # drop cond + total = 0 + for le in lens: + if random.random() < self.cond_drop_rate: + kv_compact[total:total+le] = self.cfg_uncond[:le] + total += le + must_on_graph = self.cfg_uncond[0, 0] * 0 + kv_compact = self.text_norm(kv_compact).contiguous() + sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32 + kv_compact = self.text_proj_for_ca(kv_compact).contiguous() + kv_compact[0, 0] += must_on_graph + ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + + cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32 + + sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1) + x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1) + # [1.1. pad the seqlen dim] + l_end = x_BLC.shape[1] + need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0 + + if self.customized_flash_attn: + Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end] + Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end] + attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen) + # todo: solve need_to_pad here + elif self.use_flex_attn: + if need_to_pad: + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0' + attn_bias_or_two_vector = None + else: + d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1) + dT = d.transpose(1, 2) # dT: 11L + attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end) + attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL + if need_to_pad: + attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf) + attn_bias[0, 0, l_end:, 0] = 0 + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device) + + if self.use_flex_attn: + attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)] + else: + attn_fn = None + + # [2. 
block loop] + SelfAttnBlock.forward, CrossAttnBlock.forward + checkpointing_full_block = self.checkpointing == 'full-block' and self.training + if self.num_block_chunks == 1: + for i, b in enumerate(self.blocks): + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if checkpointing_full_block: + x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False) + else: + x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) + else: + for i, chunk in enumerate(self.block_chunks): # this path + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) + + # [3. unpad the seqlen dim, and then get logits] + return self.get_logits(x_BLC[:, :l_end], cond_BD) # return logits BLV, V is vocab_size + + + @for_visualize + def vis_key_params(self, ep): + return + + def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False): + for k in state_dict: + if 'cfg_uncond' in k: + old, new = state_dict[k], self.cfg_uncond.data + min_tlen = min(old.shape[0], new.shape[0]) + if min_tlen == old.shape[0]: + state_dict[k] = torch.cat((old.to(device=new.device, dtype=new.dtype), new[min_tlen:])) + else: + state_dict[k] = old[:min_tlen] + + for buf_name in ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen'): + state_dict.pop(buf_name, None) + if hasattr(self, buf_name): + state_dict[buf_name] = getattr(self, buf_name) + + return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign) + + def special_init( + self, + aln_init: float, + aln_gamma_init: float, + scale_head: float, + scale_proj: int, + ): + # init head's norm + if isinstance(self.head_nm, AdaLNBeforeHead): + self.head_nm.ada_lin[-1].weight.data.mul_(aln_init) # there's no gamma for head + if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None: + self.head_nm.ada_lin[-1].bias.data.zero_() + + # init head's proj + if scale_head >= 0: + if isinstance(self.head, nn.Linear): + self.head.weight.data.mul_(scale_head) + self.head.bias.data.zero_() + elif isinstance(self.head, nn.Sequential): + self.head[-1].weight.data.mul_(scale_head) + self.head[-1].bias.data.zero_() + + depth = len(self.unregistered_blocks) + for block_idx, sab in enumerate(self.unregistered_blocks): + sab: Union[SelfAttnBlock, CrossAttnBlock] + # init proj + scale = 1 / math.sqrt(2*depth if scale_proj == 1 else 2*(1 + block_idx)) + if scale_proj == 1: + if self.t2i: + sab.sa.proj.weight.data.mul_(scale) + sab.ca.proj.weight.data.mul_(scale) + else: + sab.attn.proj.weight.data.mul_(scale) + sab.ffn.fc2.weight.data.mul_(scale) + # if sab.using_swiglu: + # nn.init.ones_(sab.ffn.fcg.bias) + # 
nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5) + + # init ada_lin + if hasattr(sab, 'ada_lin'): + lin = sab.ada_lin[-1] + lin.weight.data[:2*self.C].mul_(aln_gamma_init) # init gamma + lin.weight.data[2*self.C:].mul_(aln_init) # init scale and shift + if hasattr(lin, 'bias') and lin.bias is not None: + lin.bias.data.zero_() + elif hasattr(sab, 'ada_gss'): + sab.ada_gss.data[:, :, :2, :].mul_(aln_gamma_init) # init gamma + sab.ada_gss.data[:, :, 2:, :].mul_(aln_init) # init scale and shift + + def extra_repr(self): + return f'drop_path_rate={self.drop_path_rate}' + + def get_layer_id_and_scale_exp(self, para_name: str): + raise NotImplementedError + +class BInfinity(nn.Module): ###backbone + def __init__( + self, vae_local, + text_channels=0, text_maxlen=0, # text-cond generation + selecting_idx=None, # class-cond generation + embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4., # model's architecture + drop_rate=0., drop_path_rate=0., # drop out and drop path + norm_eps=1e-6, rms_norm=False, # norm layer + shared_aln=False, head_aln=True, # adaptive norm + cond_drop_rate=0.1, # for classifier-free guidance + rand_uncond=False, + cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False, + raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), + head_depth=1, + top_p=0.0, top_k=0.0, + customized_flash_attn=False, fused_mlp=False, fused_norm=False, + block_chunks=1, + checkpointing=None, + pad_to_multiplier=0, + use_flex_attn=False, + batch_size=2, + add_lvl_embeding_only_first_block=1, + use_bit_label=1, + rope2d_each_sa_layer=0, + rope2d_normalized_by_hw=0, + pn=None, + train_h_div_w_list=None, + video_frames=1, + always_training_scales=20, + apply_spatial_patchify = 0, + inference_mode=False, + ): + # set hyperparameters + self.C = embed_dim + self.inference_mode = inference_mode + self.apply_spatial_patchify = apply_spatial_patchify + if self.apply_spatial_patchify: + self.d_vae = vae_local.embed_dim * 4 + else: + self.d_vae = vae_local.embed_dim + self.use_bit_label = use_bit_label + self.codebook_dim = self.d_vae + self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size + self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None + self.Ct5 = text_channels + self.depth = depth + self.num_heads = num_heads + self.batch_size = batch_size + self.mlp_ratio = mlp_ratio + self.cond_drop_rate = cond_drop_rate + self.norm_eps = norm_eps + self.prog_si = -1 + self.pn = pn + self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates + self.video_frames = video_frames + self.always_training_scales = always_training_scales + + assert add_lvl_embeding_only_first_block in [0,1] + self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block + assert rope2d_each_sa_layer in [0,1] + self.rope2d_each_sa_layer = rope2d_each_sa_layer + self.rope2d_normalized_by_hw = rope2d_normalized_by_hw + print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, \ + self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}') + head_up_method = '' + word_patch_size = 1 if head_up_method in {'', 'no'} else 2 + if word_patch_size > 1: + assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}' + + self.checkpointing = checkpointing + 
self.pad_to_multiplier = max(1, pad_to_multiplier) + + customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames) + self.customized_flash_attn = customized_flash_attn and customized_kernel_installed + if customized_flash_attn and not customized_kernel_installed: + import inspect, warnings + file_path = inspect.getsourcefile(flash_attn_func) + line_number = inspect.getsourcelines(flash_attn_func)[1] + info = ( + f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<\n' + f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n' + f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n' + ) + warnings.warn(info, ImportWarning) + print(info, flush=True) + + self.raw_scale_schedule = raw_scale_schedule # 'raw' means before any patchifying + self.first_l = 1 + # solve top-p top-k sampling hyperparameters + self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k)) + if self.top_p < 1e-5: self.top_p = 0 + if self.top_k >= self.V or self.top_k <= 0: self.top_k = 0 + + t = torch.zeros(dist.get_world_size(), device=dist.get_device()) + t[dist.get_rank()] = float(flash_fused_op_installed) + dist.barrier() + dist.allreduce(t) + assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}' + + super().__init__() + self.rng = torch.Generator(device=dist.get_device()) + self.maybe_record_function = nullcontext + self.text_maxlen = text_maxlen + self.t2i = text_channels != 0 + + # [inp & position embedding] + init_std = math.sqrt(1 / self.C / 3) + self.norm0_cond = nn.Identity() + if self.t2i: + self.selecting_idx = None + self.num_classes = 0 + self.D = self.C + + cfg_uncond = torch.empty(self.text_maxlen, self.Ct5) + rng = torch.Generator(device='cpu') + rng.manual_seed(0) + torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng) + cfg_uncond /= self.Ct5 ** 0.5 + if rand_uncond: + self.register_buffer('cfg_uncond', cfg_uncond) + else: + self.cfg_uncond = nn.Parameter(cfg_uncond) + + self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps) + self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D) + self.text_proj_for_ca = nn.Sequential( + nn.Linear(self.Ct5, self.D), + nn.GELU(approximate='tanh'), + nn.Linear(self.D, self.D), + ) + else: # class-label cond + if selecting_idx is None: + num_classes = 1000 + print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======') + selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device()) + self.selecting_idx = selecting_idx + self.num_classes = selecting_idx.shape[-1] + self.D = self.C + self.class_emb = nn.Embedding(self.num_classes + 1, self.C) + nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std) + + self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C)) + nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std) + if self.rope2d_each_sa_layer: + rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw) + self.rope2d_freqs_grid = rope2d_freqs_grid + else: + raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented') + self.lvl_embed = nn.Embedding(15, 
self.C) + nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std) + + # [input layers] input norm && input embedding + norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps) + self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity() + self.word_embed = nn.Linear(self.d_vae, self.C) + + # [shared adaptive layernorm mapping network] + self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity() + + # fused norm + if fused_norm: + fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm + if fused_norm_func is not None: # pre-compile + B = 2 + x = torch.randn(B, 1, self.C).requires_grad_(True) + scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + # fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale, shift=shift).mean().backward() + del B, x, scale, shift + else: + fused_norm_func = None + + # [backbone and head] + self.use_flex_attn = use_flex_attn + self.attn_fn_compile_dict = {} + self.batch_size = batch_size + if self.use_flex_attn: + self.attn_fn_compile_dict = self.compile_flex_attn() + + self.drop_path_rate = drop_path_rate + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # dpr means drop path rate (linearly increasing) + self.unregistered_blocks = [] + for block_idx in range(depth): + block = (CrossAttnBlock if self.t2i else SelfAttnBlock)( + embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer, + num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn, + swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func, + checkpointing_sa_only=self.checkpointing == 'self-attn', + use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw, + ) + self.unregistered_blocks.append(block) + + # [head] + V = self.V + if head_aln: + self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func) + self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V)) + else: + self.head_nm = MultiInpIdentity() + self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V)) + + self.num_block_chunks = block_chunks or 1 + self.num_blocks_in_a_chunk = depth // block_chunks + print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}") + assert self.num_blocks_in_a_chunk * block_chunks == depth + if self.num_block_chunks == 1: + self.blocks = nn.ModuleList(self.unregistered_blocks) + else: + self.block_chunks = nn.ModuleList() + for i in range(self.num_block_chunks): + self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk)) + print( + f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in 
self.unregistered_blocks)}/{self.depth}) ==== \n' + f' [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n' + f' [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})', + end='\n\n', flush=True + ) + + def compile_flex_attn(self): + attn_fn_compile_dict = {} + for h_div_w in self.train_h_div_w_list: + h_div_w_template = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))] + full_scale_schedule = dynamic_resolution_h_w[h_div_w_template][self.pn]['scales'] + if self.inference_mode: + apply_flex_attn_scales = list(range(1, 1+len(full_scale_schedule))) + mask_type = "infinity_infer_mask_with_kv_cache" + auto_padding = True + else: + mask_type = 'var' + auto_padding = False + apply_flex_attn_scales = [min(self.always_training_scales, len(full_scale_schedule))] + for scales_num in apply_flex_attn_scales: + print(f'====== apply flex attn hdivw: {h_div_w} scales: {scales_num} ======') + scale_schedule = full_scale_schedule[:scales_num] + scale_schedule = [ (min(t, self.video_frames//4+1), h, w) for (t,h, w) in scale_schedule] + patchs_nums_tuple = tuple(scale_schedule) + SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple) + aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L + attn_fn = FlexAttn(block_scales = patchs_nums_tuple, + mask_type = mask_type, + B = self.batch_size, + H = self.num_heads, + L = aligned_L, + auto_padding=auto_padding) + attn_fn_compile_dict[patchs_nums_tuple] = attn_fn + + if self.video_frames > 1: # append image attn_fn when self.video_frames > 1 (namely videos) + scale_schedule = [ (1, h, w) for (t,h, w) in scale_schedule] + patchs_nums_tuple = tuple(scale_schedule) + SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple) + aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L + attn_fn = FlexAttn(block_scales = patchs_nums_tuple, + mask_type = mask_type, + B = self.batch_size, + H = self.num_heads, + L = aligned_L) + attn_fn_compile_dict[patchs_nums_tuple] = attn_fn + return attn_fn_compile_dict + + def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]): + """ + :param h: hidden_state, shaped (B or batch_size, L or seq_len, C or hidden_dim) + :param cond_BD: shaped (B or batch_size, D or cond_dim) + :param tau: temperature + :return: logits, shaped (B or batch_size, V or vocabulary_size) + """ + with torch.amp.autocast('cuda', enabled=False): + return self.head(self.head_nm(h.float(), cond_BD.float())) + + def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0): + bs, seq_len, c = feature.shape + patch_t, patch_h, patch_w = scale_schedule[scale_ind] + t_mul_h_mul_w = patch_t * patch_h * patch_w + assert t_mul_h_mul_w + need_to_pad == seq_len + feature[:, :t_mul_h_mul_w] += self.lvl_embed(scale_ind*torch.ones((bs, t_mul_h_mul_w),dtype=torch.int).to(feature.device)) + return feature + + def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0): + ptr = 0 + x_BLC_list = [] + for scale_ind, patch_t_h_w in enumerate(scale_schedule): + scale_seq_len = np.array(patch_t_h_w).prod() + x_BLC_this_scale = x_BLC[:,ptr:ptr+scale_seq_len] # shape: [bs, patch_h*patch_w, c] + ptr += scale_seq_len + x_BLC_this_scale = self.add_lvl_embeding(x_BLC_this_scale, scale_ind, 
scale_schedule) + x_BLC_list.append(x_BLC_this_scale) + assert x_BLC.shape[1] == (ptr + need_to_pad), f'{x_BLC.shape[1]} != {ptr} + {need_to_pad}' + x_BLC_list.append(x_BLC[:,ptr:]) + x_BLC = torch.cat(x_BLC_list, dim=1) + return x_BLC + + def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]], + cfg_infer=False, + **kwargs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV + """ + label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k) + :return: logits BLV, V is vocab_size + """ + if cfg_infer: + return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs) + + x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32 + B = x_BLC_wo_prefix.shape[0] + + # [1. get input sequence x_BLC] + with torch.amp.autocast('cuda', enabled=False): + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + # drop cond + total = 0 + for le in lens: + if random.random() < self.cond_drop_rate: + kv_compact[total:total+le] = self.cfg_uncond[:le] + total += le + must_on_graph = self.cfg_uncond[0, 0] * 0 + kv_compact = self.text_norm(kv_compact).contiguous() + sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32 + kv_compact = self.text_proj_for_ca(kv_compact).contiguous() + kv_compact[0, 0] += must_on_graph + ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + + cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32 + + sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1) + x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1) + + all_scale_length = np.sum([np.prod(scale_schedule[j]) for j in range(len(scale_schedule))]) + last_scale_length = np.prod(scale_schedule[-1]) + if x_BLC.shape[1] == all_scale_length: + long_input = 0 + else: + assert x_BLC.shape[1] == all_scale_length + last_scale_length + long_input = 1 + + # [1.1. pad the seqlen dim] + l_end = x_BLC.shape[1] + need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0 + + # if self.customized_flash_attn: + # Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end] + # Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end] + # attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen) + # # todo: solve need_to_pad here + # elif self.use_flex_attn: + # if need_to_pad: + # x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + # assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0' + # attn_bias_or_two_vector = None + # else: + # d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1) + # dT = d.transpose(1, 2) # dT: 11L + # attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end) + # attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL + # if need_to_pad: + # attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf) + # attn_bias[0, 0, l_end:, 0] = 0 + # x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + # attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device) + + if self.use_flex_attn: + attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)] + else: + attn_fn = None + + # [2. 
block loop]
+        SelfAttnBlock.forward, CrossAttnBlock.forward
+        checkpointing_full_block = self.checkpointing == 'full-block' and self.training
+        if long_input == 1:
+            scale_schedule_new = scale_schedule + [scale_schedule[-1]]
+        else:
+            scale_schedule_new = scale_schedule
+        if self.num_block_chunks == 1:
+            for i, b in enumerate(self.blocks):
+                if self.add_lvl_embeding_only_first_block and i == 0:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule_new, need_to_pad)
+                if not self.add_lvl_embeding_only_first_block:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule_new, need_to_pad)
+                if checkpointing_full_block:
+                    x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, None, None, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
+                else:
+                    x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
+        else:
+            for i, chunk in enumerate(self.block_chunks):  # this path
+                if self.add_lvl_embeding_only_first_block and i == 0:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule_new, need_to_pad)
+                if not self.add_lvl_embeding_only_first_block:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule_new, need_to_pad)
+                # the chunk's scale_schedule argument is used to select the rotary embedding
+                x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)
+
+        # [3. unpad the seqlen dim, and then get logits]
+        # only the first all_scale_length positions carry logits; for a long input
+        # the trailing duplicated-scale tokens are dropped here
+        return self.get_logits(x_BLC[:, :all_scale_length], cond_BD)  # return logits BLV, V is vocab_size
+
+    def forward_teacher(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
+        cfg_infer=False,
+        **kwargs,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:  # returns logits_BLV
+        """
+        label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k)
+        :return: logits BLV, V is vocab_size
+        """
+        if cfg_infer:
+            return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)
+
+        x_BLC_wo_prefix = x_BLC_wo_prefix.float()  # input should be float32
+        B = x_BLC_wo_prefix.shape[0]
+
+        # [1.
get input sequence x_BLC] + with torch.amp.autocast('cuda', enabled=False): + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + # drop cond + total = 0 + for le in lens: + if random.random() < self.cond_drop_rate: + kv_compact[total:total+le] = self.cfg_uncond[:le] + total += le + must_on_graph = self.cfg_uncond[0, 0] * 0 + kv_compact = self.text_norm(kv_compact).contiguous() + sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32 + kv_compact = self.text_proj_for_ca(kv_compact).contiguous() + kv_compact[0, 0] += must_on_graph + ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + + cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32 + + sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1) + x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1) + # [1.1. pad the seqlen dim] + l_end = x_BLC.shape[1] + need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0 + + if self.customized_flash_attn: + Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end] + Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end] + attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen) + # todo: solve need_to_pad here + elif self.use_flex_attn: + if need_to_pad: + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0' + attn_bias_or_two_vector = None + else: + d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1) + dT = d.transpose(1, 2) # dT: 11L + attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end) + attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL + if need_to_pad: + attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf) + attn_bias[0, 0, l_end:, 0] = 0 + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device) + + if self.use_flex_attn: + attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)] + else: + attn_fn = None + + # [2. 
block loop]
+        SelfAttnBlock.forward, CrossAttnBlock.forward
+        checkpointing_full_block = self.checkpointing == 'full-block' and self.training
+        if self.num_block_chunks == 1:
+            for i, b in enumerate(self.blocks):
+                if self.add_lvl_embeding_only_first_block and i == 0:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
+                if not self.add_lvl_embeding_only_first_block:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
+                if checkpointing_full_block:
+                    x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
+                else:
+                    x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
+        else:
+            for i, chunk in enumerate(self.block_chunks):  # this path
+                if self.add_lvl_embeding_only_first_block and i == 0:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
+                if not self.add_lvl_embeding_only_first_block:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
+                x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)
+
+        # [3. unpad the seqlen dim, and then get logits]
+        return self.get_logits(x_BLC[:, :l_end], cond_BD)  # return logits BLV, V is vocab_size
+
+    def logits_to_img(self, logits_BlV_all, vae, scale_schedule, top_k=900, top_p=0.97, g_seed=1):
+        # logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si])
+        patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]  # note: pn[0] (the temporal extent) is 1 for images
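+        # e.g. with an image schedule scale_schedule = [(1,1,1), (1,2,2), (1,4,4)],
+        # patch_nums_per_level = [1, 4, 16]; torch.split below then cuts the
+        # 21-token axis of logits_BlV_all into one chunk of logits per scale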
+        logits_BlV_list = list(torch.split(logits_BlV_all, patch_nums_per_level, dim=1))
+
+        B = logits_BlV_all.shape[0]
+
+        if g_seed is None: rng = None
+        else: self.rng.manual_seed(g_seed); rng = self.rng
+
+        if self.apply_spatial_patchify:
+            vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
+        else:
+            vae_scale_schedule = scale_schedule
+
+        summed_codes = 0
+        num_stages_minus_1 = len(scale_schedule) - 1
+
+        for si, logits_BlV in enumerate(logits_BlV_list):
+            pn = scale_schedule[si]
+            if self.use_bit_label:
+                tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
+                logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
+
+                # sampling alternative:
+                # idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
+                # straight-through alternative:
+                # idx_Bld = STGumbelArgmax.apply(logits_BlV, 0.5)
+
+                # hard one-hot selection over the per-bit binary logits via
+                # GumbelArgmax (cf. the STGumbelArgmax variant defined above);
+                # masking out channel 0 and summing the last dim turns the
+                # one-hot [*, 2] output into a {0, 1} bit value
+                idx_Bld = GumbelArgmax(logits_BlV, 0.5)
+                tmp_tensor = torch.zeros_like(idx_Bld)
+                tmp_tensor[:, :, 1:] = 1
+                idx_Bld = idx_Bld * tmp_tensor
+                idx_Bld = idx_Bld.sum(dim=-1)
+
+                idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)
+            else:
+                # note: the code below uses idx_Bld, so this non-bit-label branch is effectively unsupported here
+                idx_Bl = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
+
+            # assumes vae_type != 0 and si >= gt_leak
+            assert pn[0] == 1
+            idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1)  # shape: [B, h, w, d] or [B, h, w, 4d]
+            if self.apply_spatial_patchify:  # unpatchify operation
+                idx_Bld = idx_Bld.permute(0, 3, 1, 2)  # [B, 4d, h, w]
+                idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2)  # [B, d, 2h, 2w]
+                idx_Bld = idx_Bld.permute(0, 2, 3, 1)  # [B, 2h, 2w, d]
+            idx_Bld = idx_Bld.unsqueeze(1)  # [B, 1, h, w, d] or [B, 1, 2h, 2w, d]
+
+            codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label')  # [B, d, 1, h, w] or [B, d, 1, 2h, 2w]
+            if si != num_stages_minus_1:
+                summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
+            else:
+                summed_codes += codes
+
+        # vae_type != 0:
+        img = vae.decode(summed_codes.squeeze(-3))
+        # scaling to [0, 1] and uint8 conversion are left to the caller
+        return img
+
+    @torch.no_grad()
+    def autoregressive_infer_cfg(
+        self,
+        vae=None,
+        scale_schedule=None,
+        label_B_or_BLT=None,
+        B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
+        g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
+        returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
+        cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
+        vae_type=0, softmax_merge_topk=-1, ret_img=False,
+        trunk_scale=1000,
+        gt_leak=0, gt_ls_Bl=None,
+        inference_mode=False,
+        save_img_path=None,
+        sampling_per_bits=1,
+        x_BLC_wo_prefix_lq=None
+    ):  # returns List[idx_Bl]
+        if g_seed is None: rng = None
+        else: self.rng.manual_seed(g_seed); rng = self.rng
+        # assert len(cfg_list) >= len(scale_schedule)
+        # assert len(tau_list) >= len(scale_schedule)
+
+        logits_BlV = 
self.forward(label_B_or_BLT,x_BLC_wo_prefix_lq,scale_schedule) + + img = self.logits_to_img(logits_BlV_all=logits_BlV, + vae=vae, + scale_schedule=scale_schedule, + top_k=top_k, + top_p=top_p, + g_seed=g_seed) + + img = (img + 1) / 2 + img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) + + return None,None,img + + + @for_visualize + def vis_key_params(self, ep): + return + + def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False): + for k in state_dict: + if 'cfg_uncond' in k: + old, new = state_dict[k], self.cfg_uncond.data + min_tlen = min(old.shape[0], new.shape[0]) + if min_tlen == old.shape[0]: + state_dict[k] = torch.cat((old.to(device=new.device, dtype=new.dtype), new[min_tlen:])) + else: + state_dict[k] = old[:min_tlen] + + for buf_name in ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen'): + state_dict.pop(buf_name, None) + if hasattr(self, buf_name): + state_dict[buf_name] = getattr(self, buf_name) + + return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign) + + def special_init( + self, + aln_init: float, + aln_gamma_init: float, + scale_head: float, + scale_proj: int, + ): + # init head's norm + if isinstance(self.head_nm, AdaLNBeforeHead): + self.head_nm.ada_lin[-1].weight.data.mul_(aln_init) # there's no gamma for head + if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None: + self.head_nm.ada_lin[-1].bias.data.zero_() + + # init head's proj + if scale_head >= 0: + if isinstance(self.head, nn.Linear): + self.head.weight.data.mul_(scale_head) + self.head.bias.data.zero_() + elif isinstance(self.head, nn.Sequential): + self.head[-1].weight.data.mul_(scale_head) + self.head[-1].bias.data.zero_() + + depth = len(self.unregistered_blocks) + for block_idx, sab in enumerate(self.unregistered_blocks): + sab: Union[SelfAttnBlock, CrossAttnBlock] + # init proj + scale = 1 / math.sqrt(2*depth if scale_proj == 1 else 2*(1 + block_idx)) + if scale_proj == 1: + if self.t2i: + sab.sa.proj.weight.data.mul_(scale) + sab.ca.proj.weight.data.mul_(scale) + else: + sab.attn.proj.weight.data.mul_(scale) + sab.ffn.fc2.weight.data.mul_(scale) + # if sab.using_swiglu: + # nn.init.ones_(sab.ffn.fcg.bias) + # nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5) + + # init ada_lin + if hasattr(sab, 'ada_lin'): + lin = sab.ada_lin[-1] + lin.weight.data[:2*self.C].mul_(aln_gamma_init) # init gamma + lin.weight.data[2*self.C:].mul_(aln_init) # init scale and shift + if hasattr(lin, 'bias') and lin.bias is not None: + lin.bias.data.zero_() + elif hasattr(sab, 'ada_gss'): + sab.ada_gss.data[:, :, :2, :].mul_(aln_gamma_init) # init gamma + sab.ada_gss.data[:, :, 2:, :].mul_(aln_init) # init scale and shift + + def extra_repr(self): + return f'drop_path_rate={self.drop_path_rate}' + + def get_layer_id_and_scale_exp(self, para_name: str): + raise NotImplementedError + +class AInfinity(nn.Module): # x_BLC add x_BLC_lq + def __init__( + self, vae_local, + text_channels=0, text_maxlen=0, # text-cond generation + selecting_idx=None, # class-cond generation + embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4., # model's architecture + drop_rate=0., drop_path_rate=0., # drop out and drop path + norm_eps=1e-6, rms_norm=False, # norm layer + shared_aln=False, head_aln=True, # adaptive norm + cond_drop_rate=0.1, # for classifier-free guidance + rand_uncond=False, + cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False, + raw_scale_schedule=(1, 2, 3, 4, 
5, 6, 8, 10, 13, 16), + head_depth=1, + top_p=0.0, top_k=0.0, + customized_flash_attn=False, fused_mlp=False, fused_norm=False, + block_chunks=1, + checkpointing=None, + pad_to_multiplier=0, + use_flex_attn=False, + batch_size=2, + add_lvl_embeding_only_first_block=1, + use_bit_label=1, + rope2d_each_sa_layer=0, + rope2d_normalized_by_hw=0, + pn=None, + train_h_div_w_list=None, + video_frames=1, + always_training_scales=20, + apply_spatial_patchify = 0, + inference_mode=False, + ): + # set hyperparameters + self.C = embed_dim + self.inference_mode = inference_mode + self.apply_spatial_patchify = apply_spatial_patchify + if self.apply_spatial_patchify: + self.d_vae = vae_local.embed_dim * 4 + else: + self.d_vae = vae_local.embed_dim + self.use_bit_label = use_bit_label + self.codebook_dim = self.d_vae + self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size + self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None + self.Ct5 = text_channels + self.depth = depth + self.num_heads = num_heads + self.batch_size = batch_size + self.mlp_ratio = mlp_ratio + self.cond_drop_rate = cond_drop_rate + self.norm_eps = norm_eps + self.prog_si = -1 + self.pn = pn + self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates + self.video_frames = video_frames + self.always_training_scales = always_training_scales + + assert add_lvl_embeding_only_first_block in [0,1] + self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block + assert rope2d_each_sa_layer in [0,1] + self.rope2d_each_sa_layer = rope2d_each_sa_layer + self.rope2d_normalized_by_hw = rope2d_normalized_by_hw + print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, \ + self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}') + head_up_method = '' + word_patch_size = 1 if head_up_method in {'', 'no'} else 2 + if word_patch_size > 1: + assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}' + + self.checkpointing = checkpointing + self.pad_to_multiplier = max(1, pad_to_multiplier) + + customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames) + self.customized_flash_attn = customized_flash_attn and customized_kernel_installed + if customized_flash_attn and not customized_kernel_installed: + import inspect, warnings + file_path = inspect.getsourcefile(flash_attn_func) + line_number = inspect.getsourcelines(flash_attn_func)[1] + info = ( + f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. 
<<<<<<\n' + f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n' + f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n' + ) + warnings.warn(info, ImportWarning) + print(info, flush=True) + + self.raw_scale_schedule = raw_scale_schedule # 'raw' means before any patchifying + self.first_l = 1 + # solve top-p top-k sampling hyperparameters + self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k)) + if self.top_p < 1e-5: self.top_p = 0 + if self.top_k >= self.V or self.top_k <= 0: self.top_k = 0 + + t = torch.zeros(dist.get_world_size(), device=dist.get_device()) + t[dist.get_rank()] = float(flash_fused_op_installed) + dist.barrier() + dist.allreduce(t) + assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}' + + super().__init__() + self.rng = torch.Generator(device=dist.get_device()) + self.maybe_record_function = nullcontext + self.text_maxlen = text_maxlen + self.t2i = text_channels != 0 + + # [inp & position embedding] + init_std = math.sqrt(1 / self.C / 3) + self.norm0_cond = nn.Identity() + if self.t2i: + self.selecting_idx = None + self.num_classes = 0 + self.D = self.C + + cfg_uncond = torch.empty(self.text_maxlen, self.Ct5) + rng = torch.Generator(device='cpu') + rng.manual_seed(0) + torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng) + cfg_uncond /= self.Ct5 ** 0.5 + if rand_uncond: + self.register_buffer('cfg_uncond', cfg_uncond) + else: + self.cfg_uncond = nn.Parameter(cfg_uncond) + + self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps) + self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D) + self.text_proj_for_ca = nn.Sequential( + nn.Linear(self.Ct5, self.D), + nn.GELU(approximate='tanh'), + nn.Linear(self.D, self.D), + ) + else: # class-label cond + if selecting_idx is None: + num_classes = 1000 + print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======') + selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device()) + self.selecting_idx = selecting_idx + self.num_classes = selecting_idx.shape[-1] + self.D = self.C + self.class_emb = nn.Embedding(self.num_classes + 1, self.C) + nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std) + + self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C)) + nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std) + if self.rope2d_each_sa_layer: + rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw) + self.rope2d_freqs_grid = rope2d_freqs_grid + else: + raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented') + self.lvl_embed = nn.Embedding(15, self.C) + nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std) + + # [input layers] input norm && input embedding + norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps) + self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity() + self.word_embed = nn.Linear(self.d_vae, self.C) + + # [shared adaptive layernorm mapping network] + self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity() + + # fused norm + if fused_norm: + fused_norm_func = fused_ada_rms_norm if rms_norm else 
fused_ada_layer_norm + if fused_norm_func is not None: # pre-compile + B = 2 + x = torch.randn(B, 1, self.C).requires_grad_(True) + scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + # fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale, shift=shift).mean().backward() + del B, x, scale, shift + else: + fused_norm_func = None + + # [backbone and head] + self.use_flex_attn = use_flex_attn + self.attn_fn_compile_dict = {} + self.batch_size = batch_size + if self.use_flex_attn: + self.attn_fn_compile_dict = self.compile_flex_attn() + + self.drop_path_rate = drop_path_rate + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # dpr means drop path rate (linearly increasing) + self.unregistered_blocks = [] + for block_idx in range(depth): + block = (CrossAttnBlock if self.t2i else SelfAttnBlock)( + embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer, + num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn, + swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func, + checkpointing_sa_only=self.checkpointing == 'self-attn', + use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw, + ) + self.unregistered_blocks.append(block) + + # [head] + V = self.V + if head_aln: + self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func) + self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V)) + else: + self.head_nm = MultiInpIdentity() + self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V)) + + self.num_block_chunks = block_chunks or 1 + self.num_blocks_in_a_chunk = depth // block_chunks + print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}") + assert self.num_blocks_in_a_chunk * block_chunks == depth + if self.num_block_chunks == 1: + self.blocks = nn.ModuleList(self.unregistered_blocks) + else: + self.block_chunks = nn.ModuleList() + for i in range(self.num_block_chunks): + self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk)) + + self.lq_linear = nn.Linear(self.C, self.C) + print( + f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n' + f' [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n' + f' [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})', + end='\n\n', flush=True + ) + + + def compile_flex_attn(self): + attn_fn_compile_dict = {} + for h_div_w in self.train_h_div_w_list: + h_div_w_template = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - 
h_div_w_templates))] + full_scale_schedule = dynamic_resolution_h_w[h_div_w_template][self.pn]['scales'] + if self.inference_mode: + apply_flex_attn_scales = list(range(1, 1+len(full_scale_schedule))) + mask_type = "infinity_infer_mask_with_kv_cache" + auto_padding = True + else: + mask_type = 'var' + auto_padding = False + apply_flex_attn_scales = [min(self.always_training_scales, len(full_scale_schedule))] + for scales_num in apply_flex_attn_scales: + print(f'====== apply flex attn hdivw: {h_div_w} scales: {scales_num} ======') + scale_schedule = full_scale_schedule[:scales_num] + scale_schedule = [ (min(t, self.video_frames//4+1), h, w) for (t,h, w) in scale_schedule] + patchs_nums_tuple = tuple(scale_schedule) + SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple) + aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L + attn_fn = FlexAttn(block_scales = patchs_nums_tuple, + mask_type = mask_type, + B = self.batch_size, + H = self.num_heads, + L = aligned_L, + auto_padding=auto_padding) + attn_fn_compile_dict[patchs_nums_tuple] = attn_fn + + if self.video_frames > 1: # append image attn_fn when self.video_frames > 1 (namely videos) + scale_schedule = [ (1, h, w) for (t,h, w) in scale_schedule] + patchs_nums_tuple = tuple(scale_schedule) + SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple) + aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L + attn_fn = FlexAttn(block_scales = patchs_nums_tuple, + mask_type = mask_type, + B = self.batch_size, + H = self.num_heads, + L = aligned_L) + attn_fn_compile_dict[patchs_nums_tuple] = attn_fn + return attn_fn_compile_dict + + def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]): + """ + :param h: hidden_state, shaped (B or batch_size, L or seq_len, C or hidden_dim) + :param cond_BD: shaped (B or batch_size, D or cond_dim) + :param tau: temperature + :return: logits, shaped (B or batch_size, V or vocabulary_size) + """ + with torch.amp.autocast('cuda', enabled=False): + return self.head(self.head_nm(h.float(), cond_BD.float())) + + def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0): + bs, seq_len, c = feature.shape + patch_t, patch_h, patch_w = scale_schedule[scale_ind] + t_mul_h_mul_w = patch_t * patch_h * patch_w + assert t_mul_h_mul_w + need_to_pad == seq_len + feature[:, :t_mul_h_mul_w] += self.lvl_embed(scale_ind*torch.ones((bs, t_mul_h_mul_w),dtype=torch.int).to(feature.device)) + return feature + + def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0): + ptr = 0 + x_BLC_list = [] + for scale_ind, patch_t_h_w in enumerate(scale_schedule): + scale_seq_len = np.array(patch_t_h_w).prod() + x_BLC_this_scale = x_BLC[:,ptr:ptr+scale_seq_len] # shape: [bs, patch_h*patch_w, c] + ptr += scale_seq_len + x_BLC_this_scale = self.add_lvl_embeding(x_BLC_this_scale, scale_ind, scale_schedule) + x_BLC_list.append(x_BLC_this_scale) + assert x_BLC.shape[1] == (ptr + need_to_pad), f'{x_BLC.shape[1]} != {ptr} + {need_to_pad}' + x_BLC_list.append(x_BLC[:,ptr:]) + x_BLC = torch.cat(x_BLC_list, dim=1) + return x_BLC + + def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]], + cfg_infer=False, x_BLC_w_prefix_lq=None, + **kwargs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV + """ + 
label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k)
+        :return: logits BLV, V is vocab_size
+        """
+        if cfg_infer:
+            return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)
+        x_BLC_wo_prefix = x_BLC_wo_prefix.float()  # input should be float32
+        x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.float()
+        B = x_BLC_wo_prefix.shape[0]
+
+        # [1. get input sequence x_BLC]
+        with torch.amp.autocast('cuda', enabled=False):
+            kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
+            # drop cond
+            total = 0
+            for le in lens:
+                if random.random() < self.cond_drop_rate:
+                    kv_compact[total:total+le] = self.cfg_uncond[:le]
+                total += le
+            must_on_graph = self.cfg_uncond[0, 0] * 0
+            kv_compact = self.text_norm(kv_compact).contiguous()
+            sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous()  # cond_BD should be float32
+            kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
+            kv_compact[0, 0] += must_on_graph
+            ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
+
+            cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous()  # gss: gamma, scale, shift; cond_BD_or_gss should be float32
+
+        sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
+        x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)
+        x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq))
+
+        # [1.1. pad the seqlen dim]
+        l_end = x_BLC.shape[1]
+        need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end  # 0
+
+        if self.customized_flash_attn:
+            Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end]
+            Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end]
+            attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen)
+            # todo: solve need_to_pad here
+        elif self.use_flex_attn:
+            if need_to_pad:
+                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
+            assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'
+            attn_bias_or_two_vector = None
+        else:
+            d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1)
+            dT = d.transpose(1, 2)  # dT: 11L
+            attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end)
+            attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous()  # attn_bias: 11LL
+            if need_to_pad:
+                attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf)
+                attn_bias[0, 0, l_end:, 0] = 0
+                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
+            attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device)
+
+        if self.use_flex_attn:
+            attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)]
+        else:
+            attn_fn = None
+
+        # inject the low-quality (LQ) conditioning: project the LQ tokens with
+        # lq_linear and add them to the input sequence before the block loop
+        x_BLC_lq = self.lq_linear(x_BLC_lq)
+        x_BLC = x_BLC + x_BLC_lq
+
+        # [2. block loop]
+        SelfAttnBlock.forward, CrossAttnBlock.forward
+        checkpointing_full_block = self.checkpointing == 'full-block' and self.training
+        if self.num_block_chunks == 1:
+            for i, b in enumerate(self.blocks):
+                if self.add_lvl_embeding_only_first_block and i == 0:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
+                if not self.add_lvl_embeding_only_first_block:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
+                if checkpointing_full_block:
+                    x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
+                else:
+                    x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
+        else:
+            for i, chunk in enumerate(self.block_chunks):  # this path
+                if self.add_lvl_embeding_only_first_block and i == 0:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
+                if not self.add_lvl_embeding_only_first_block:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
+                x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)
+
+        # [3. unpad the seqlen dim, and then get logits]
+        return self.get_logits(x_BLC[:, :l_end], cond_BD)  # return logits BLV, V is vocab_size
+
+    @torch.no_grad()
+    def autoregressive_infer_cfg(
+        self,
+        vae=None,
+        scale_schedule=None,
+        label_B_or_BLT=None,
+        B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
+        g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
+        returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
+        cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
+        vae_type=0, softmax_merge_topk=-1, ret_img=False,
+        trunk_scale=1000,
+        gt_leak=0, gt_ls_Bl=None,
+        inference_mode=False,
+        save_img_path=None,
+        sampling_per_bits=1,
+        x_BLC_w_prefix_lq=None,
+        # x_BLC_wo_prefix=None,
+    ):  # returns List[idx_Bl]
+        if g_seed is None: rng = None
+        else: self.rng.manual_seed(g_seed); rng = self.rng
+        assert len(cfg_list) >= len(scale_schedule)
+        assert len(tau_list) >= len(scale_schedule)
+
+        # project the low-quality (LQ) conditioning features once, up front
+        x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.float()
+        x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq))
+        x_BLC_lq = self.lq_linear(x_BLC_lq)
+        patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]  # note: pn[0] (the temporal extent) is 1 for images
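+        # splitting the projected LQ features per scale lets the loop below add
+        # the matching chunk (`last_stage = last_stage + x_BLC_lq_list[si]`) to
+        # each scale's tokens before they pass through the transformer blocks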
+ x_BLC_lq_list = list(torch.split(x_BLC_lq,patch_nums_per_level,dim=1)) + ##### + + # scale_schedule is used by infinity, vae_scale_schedule is used by vae if there exists a spatial patchify, + # we need to convert scale_schedule to vae_scale_schedule by multiply 2 to h and w + if self.apply_spatial_patchify: + vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] + else: + vae_scale_schedule = scale_schedule + + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + if any(np.array(cfg_list) != 1): + bs = 2*B + if not negative_label_B_or_BLT: + kv_compact_un = kv_compact.clone() + total = 0 + for le in lens: + kv_compact_un[total:total+le] = (self.cfg_uncond)[:le] + total += le + kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) + cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0) + else: + kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT + kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) + cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0) + max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un) + else: + bs = B + + kv_compact = self.text_norm(kv_compact) + sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) # sos shape: [2, 4096] + kv_compact = self.text_proj_for_ca(kv_compact) # kv_compact shape: [304, 4096] + ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1) + + with torch.amp.autocast('cuda', enabled=False): + cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous() + accu_BChw, cur_L, ret = None, 0, [] # current length, list of reconstructed images + idx_Bl_list, idx_Bld_list = [], [] + + if inference_mode: + for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True) + else: + assert self.num_block_chunks > 1 + for block_chunk_ in self.block_chunks: + for module in block_chunk_.module.module: + (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True) + + abs_cfg_insertion_layers = [] + add_cfg_on_logits, add_cfg_on_probs = False, False + leng = len(self.unregistered_blocks) + for item in cfg_insertion_layer: + if item == 0: # add cfg on logits + add_cfg_on_logits = True + elif item == 1: # add cfg on probs + add_cfg_on_probs = True # todo in the future, we may want to add cfg on logits and probs + elif item < 0: # determine to add cfg at item-th layer's output + assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={self.num_block_chunks}' + abs_cfg_insertion_layers.append(leng+item) + else: + raise ValueError(f'cfg_insertion_layer: {item} is not valid') + + num_stages_minus_1 = len(scale_schedule)-1 + summed_codes = 0 + + # x_BLC = torch.cat((last_stage, self.word_embed(self.norm0_ve(x_BLC_wo_prefix.float()))), dim=1) + # x_BLC_list = list(torch.split(x_BLC,patch_nums_per_level,dim=1)) + + for si, pn in enumerate(scale_schedule): # si: i-th segment + cfg = cfg_list[si] + if si >= trunk_scale: + break + cur_L += np.array(pn).prod() + + # last_stage = x_BLC_list[si] + last_stage = last_stage + x_BLC_lq_list[si] + + need_to_pad = 0 + attn_fn = None + if self.use_flex_attn: + # need_to_pad = (self.pad_to_multiplier - cur_L % self.pad_to_multiplier) % self.pad_to_multiplier + # if need_to_pad: + # last_stage = F.pad(last_stage, (0, 0, 0, need_to_pad)) + attn_fn = 
self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None) + + # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item' + layer_idx = 0 + for block_idx, b in enumerate(self.block_chunks): + # last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape [5], ca_kv[2]: int + if self.add_lvl_embeding_only_first_block and block_idx == 0: + last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) + if not self.add_lvl_embeding_only_first_block: + last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) + for m in b.module: + last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si) + if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers): + # print(f'add cfg={cfg} on {layer_idx}-th layer output') + last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:] + last_stage = torch.cat((last_stage, last_stage), 0) + layer_idx += 1 + + if (cfg != 1) and add_cfg_on_logits: + # print(f'add cfg on add_cfg_on_logits') + logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si]) + logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:] + else: + logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si]) + + if self.use_bit_label: + tmp_bs, tmp_seq_len = logits_BlV.shape[:2] + logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) + idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) + else: + idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + if vae_type != 0: + assert returns_vemb + if si < gt_leak: + idx_Bld = gt_ls_Bl[si] + else: + assert pn[0] == 1 + idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d] + if self.apply_spatial_patchify: # unpatchify operation + idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w] + idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w] + idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d] + idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d] + + idx_Bld_list.append(idx_Bld) + codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + if si != num_stages_minus_1: + summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) + last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_down) # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + last_stage = last_stage.squeeze(-3) # [B, d, h, w] or [B, d, 2h, 2w] + if self.apply_spatial_patchify: # patchify operation + last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2) # [B, 4d, h, w] + last_stage = last_stage.reshape(*last_stage.shape[:2], -1) # [B, d, h*w] or [B, 4d, h*w] + last_stage = torch.permute(last_stage, [0,2,1]) # [B, h*w, d] or [B, h*w, 4d] + else: + summed_codes += codes + else: + if si < gt_leak: + idx_Bl = gt_ls_Bl[si] + h_BChw = 
self.quant_only_used_in_inference[0].embedding(idx_Bl).float()  # BlC
+
+                # h_BChw = h_BChw.float().transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1])
+                h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2])
+                ret.append(h_BChw if returns_vemb != 0 else idx_Bl)
+                idx_Bl_list.append(idx_Bl)
+                if si != num_stages_minus_1:
+                    accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule)
+
+            if si != num_stages_minus_1:
+                last_stage = self.word_embed(self.norm0_ve(last_stage))
+                last_stage = last_stage.repeat(bs//B, 1, 1)
+
+        if inference_mode:
+            for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False)
+        else:
+            assert self.num_block_chunks > 1
+            for block_chunk_ in self.block_chunks:
+                for module in block_chunk_.module.module:
+                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False)
+
+        if not ret_img:
+            return ret, idx_Bl_list, []
+
+        if vae_type != 0:
+            img = vae.decode(summed_codes.squeeze(-3))
+        else:
+            img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True)
+
+        img = (img + 1) / 2
+        img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
+        return ret, idx_Bl_list, img
+
+    def logits_to_img(self, logits_BlV_all, vae, scale_schedule, top_k=900, top_p=0.97, g_seed=1):
+        # logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si])
+        patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]  # note: pn[0] (the temporal extent) is 1 for images
+        logits_BlV_list = list(torch.split(logits_BlV_all, patch_nums_per_level, dim=1))
+
+        B = logits_BlV_all.shape[0]
+
+        if g_seed is None: rng = None
+        else: self.rng.manual_seed(g_seed); rng = self.rng
+
+        if self.apply_spatial_patchify:
+            vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
+        else:
+            vae_scale_schedule = scale_schedule
+
+        summed_codes = 0
+        num_stages_minus_1 = len(scale_schedule) - 1
+
+        for si, logits_BlV in enumerate(logits_BlV_list):
+            pn = scale_schedule[si]
+            if self.use_bit_label:
+                tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
+                logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
+
+                # sampling alternative:
+                # idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
+                # straight-through alternative:
+                # idx_Bld = STGumbelArgmax.apply(logits_BlV, 0.5)
+
+                # hard one-hot selection via GumbelArgmax, as in BInfinity.logits_to_img;
+                # masking out channel 0 and summing the last dim yields {0, 1} bits
+                idx_Bld = GumbelArgmax(logits_BlV, 0.5)
+                tmp_tensor = torch.zeros_like(idx_Bld)
+                tmp_tensor[:, :, 1:] = 1
+                idx_Bld = idx_Bld * tmp_tensor
+                idx_Bld = idx_Bld.sum(dim=-1)
+
+                idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)
+            else:
+                # note: the code below uses idx_Bld, so this non-bit-label branch is effectively unsupported here
+                idx_Bl = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
+
+            # assumes vae_type != 0 and si >= gt_leak
+            assert pn[0] == 1
+            idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1)  # shape: [B, h, w, d] or [B, h, w, 4d]
+            if self.apply_spatial_patchify:  # unpatchify operation
+                idx_Bld = idx_Bld.permute(0, 3, 1, 2)  # [B, 4d, h, w]
+                idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2)  # [B, d, 2h, 2w]
+                idx_Bld = idx_Bld.permute(0, 2, 3, 1)  # [B, 2h, 2w, d]
+            idx_Bld = idx_Bld.unsqueeze(1)  # [B, 1, h, w, d] or [B, 
1, 2h, 2w, d] + + + codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + if si != num_stages_minus_1: + summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) + else: + summed_codes += codes + + # if inference_mode: + # for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) + # else: + # assert self.num_block_chunks > 1 + # for block_chunk_ in self.block_chunks: + # for module in block_chunk_.module.module: + # (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) + + #vae_type != 0: + img = vae.decode(summed_codes.squeeze(-3)) + # img = (img + 1) / 2 + # img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) + return img + + def logits_to_img_gumble(self,logits_BlV_all,vae,scale_schedule,top_k=900,top_p=0.97,g_seed=1): + # logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si]) + patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # note important pn[0]==1? + logits_BlV_list = list(torch.split(logits_BlV_all,patch_nums_per_level,dim=1)) + + B = logits_BlV_all.shape[0] + + if g_seed is None: rng = None + else: self.rng.manual_seed(g_seed); rng = self.rng + + if self.apply_spatial_patchify: + vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] + else: + vae_scale_schedule = scale_schedule + + summed_codes = 0 + num_stages_minus_1 = len(scale_schedule)-1 + + for si,logits_BlV in enumerate(logits_BlV_list): + pn= scale_schedule[si] + if self.use_bit_label: + tmp_bs, tmp_seq_len = logits_BlV.shape[:2] + logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) + + idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + + # ##### + # idx_Bld = GumbelArgmax(logits_BlV, 0.5) + # tmp_tensor = torch.zeros_like(idx_Bld).to(idx_Bld.device) + # tmp_tensor[:,:,1:]=1 + # idx_Bld = idx_Bld * tmp_tensor + # idx_Bld = idx_Bld.sum(dim=-1) + # ##### + + idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) + else: + idx_Bl = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + + ##### vae_type!=0 + ###si>=gt_leak + assert pn[0] == 1 + idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d] + if self.apply_spatial_patchify: # unpatchify operation + idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w] + idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w] + idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d] + idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d] + + + codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + if si != num_stages_minus_1: + summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) + else: + summed_codes += codes + + # if inference_mode: + # for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) + # else: + # assert self.num_block_chunks > 1 + # for block_chunk_ in self.block_chunks: + # for module in block_chunk_.module.module: + # (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) + + #vae_type != 0: + img = vae.decode(summed_codes.squeeze(-3)) + # img = (img + 1) / 2 + # img = 
img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) + return img + + @for_visualize + def vis_key_params(self, ep): + return + + def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False): + for k in state_dict: + if 'cfg_uncond' in k: + old, new = state_dict[k], self.cfg_uncond.data + min_tlen = min(old.shape[0], new.shape[0]) + if min_tlen == old.shape[0]: + state_dict[k] = torch.cat((old.to(device=new.device, dtype=new.dtype), new[min_tlen:])) + else: + state_dict[k] = old[:min_tlen] + + for buf_name in ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen'): + state_dict.pop(buf_name, None) + if hasattr(self, buf_name): + state_dict[buf_name] = getattr(self, buf_name) + + return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign) + + def special_init( + self, + aln_init: float, + aln_gamma_init: float, + scale_head: float, + scale_proj: int, + ): + # init head's norm + if isinstance(self.head_nm, AdaLNBeforeHead): + self.head_nm.ada_lin[-1].weight.data.mul_(aln_init) # there's no gamma for head + if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None: + self.head_nm.ada_lin[-1].bias.data.zero_() + + # init head's proj + if scale_head >= 0: + if isinstance(self.head, nn.Linear): + self.head.weight.data.mul_(scale_head) + self.head.bias.data.zero_() + elif isinstance(self.head, nn.Sequential): + self.head[-1].weight.data.mul_(scale_head) + self.head[-1].bias.data.zero_() + + depth = len(self.unregistered_blocks) + for block_idx, sab in enumerate(self.unregistered_blocks): + sab: Union[SelfAttnBlock, CrossAttnBlock] + # init proj + scale = 1 / math.sqrt(2*depth if scale_proj == 1 else 2*(1 + block_idx)) + if scale_proj == 1: + if self.t2i: + sab.sa.proj.weight.data.mul_(scale) + sab.ca.proj.weight.data.mul_(scale) + else: + sab.attn.proj.weight.data.mul_(scale) + sab.ffn.fc2.weight.data.mul_(scale) + # if sab.using_swiglu: + # nn.init.ones_(sab.ffn.fcg.bias) + # nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5) + + # init ada_lin + if hasattr(sab, 'ada_lin'): + lin = sab.ada_lin[-1] + lin.weight.data[:2*self.C].mul_(aln_gamma_init) # init gamma + lin.weight.data[2*self.C:].mul_(aln_init) # init scale and shift + if hasattr(lin, 'bias') and lin.bias is not None: + lin.bias.data.zero_() + elif hasattr(sab, 'ada_gss'): + sab.ada_gss.data[:, :, :2, :].mul_(aln_gamma_init) # init gamma + sab.ada_gss.data[:, :, 2:, :].mul_(aln_init) # init scale and shift + + def extra_repr(self): + return f'drop_path_rate={self.drop_path_rate}' + + def get_layer_id_and_scale_exp(self, para_name: str): + raise NotImplementedError + +class FAInfinity(nn.Module): # x_BLC add x_BLC_lq + def __init__( + self, vae_local, + text_channels=0, text_maxlen=0, # text-cond generation + selecting_idx=None, # class-cond generation + embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4., # model's architecture + drop_rate=0., drop_path_rate=0., # drop out and drop path + norm_eps=1e-6, rms_norm=False, # norm layer + shared_aln=False, head_aln=True, # adaptive norm + cond_drop_rate=0.1, # for classifier-free guidance + rand_uncond=False, + cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False, + raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), + head_depth=1, + top_p=0.0, top_k=0.0, + customized_flash_attn=False, fused_mlp=False, fused_norm=False, + block_chunks=1, + checkpointing=None, + pad_to_multiplier=0, + use_flex_attn=False, + batch_size=2, + 
add_lvl_embeding_only_first_block=1, + use_bit_label=1, + rope2d_each_sa_layer=0, + rope2d_normalized_by_hw=0, + pn=None, + train_h_div_w_list=None, + video_frames=1, + always_training_scales=20, + apply_spatial_patchify = 0, + inference_mode=False, + ): + # set hyperparameters + self.C = embed_dim + self.inference_mode = inference_mode + self.apply_spatial_patchify = apply_spatial_patchify + if self.apply_spatial_patchify: + self.d_vae = vae_local.embed_dim * 4 + else: + self.d_vae = vae_local.embed_dim + self.use_bit_label = use_bit_label + self.codebook_dim = self.d_vae + self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size + self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None + self.Ct5 = text_channels + self.depth = depth + self.num_heads = num_heads + self.batch_size = batch_size + self.mlp_ratio = mlp_ratio + self.cond_drop_rate = cond_drop_rate + self.norm_eps = norm_eps + self.prog_si = -1 + self.pn = pn + self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates + self.video_frames = video_frames + self.always_training_scales = always_training_scales + + assert add_lvl_embeding_only_first_block in [0,1] + self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block + assert rope2d_each_sa_layer in [0,1] + self.rope2d_each_sa_layer = rope2d_each_sa_layer + self.rope2d_normalized_by_hw = rope2d_normalized_by_hw + print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, \ + self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}') + head_up_method = '' + word_patch_size = 1 if head_up_method in {'', 'no'} else 2 + if word_patch_size > 1: + assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}' + + self.checkpointing = checkpointing + self.pad_to_multiplier = max(1, pad_to_multiplier) + + customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames) + self.customized_flash_attn = customized_flash_attn and customized_kernel_installed + if customized_flash_attn and not customized_kernel_installed: + import inspect, warnings + file_path = inspect.getsourcefile(flash_attn_func) + line_number = inspect.getsourcelines(flash_attn_func)[1] + info = ( + f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. 
<<<<<<\n' + f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n' + f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n' + ) + warnings.warn(info, ImportWarning) + print(info, flush=True) + + self.raw_scale_schedule = raw_scale_schedule # 'raw' means before any patchifying + self.first_l = 1 + # solve top-p top-k sampling hyperparameters + self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k)) + if self.top_p < 1e-5: self.top_p = 0 + if self.top_k >= self.V or self.top_k <= 0: self.top_k = 0 + + t = torch.zeros(dist.get_world_size(), device=dist.get_device()) + t[dist.get_rank()] = float(flash_fused_op_installed) + dist.barrier() + dist.allreduce(t) + assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}' + + super().__init__() + self.rng = torch.Generator(device=dist.get_device()) + self.maybe_record_function = nullcontext + self.text_maxlen = text_maxlen + self.t2i = text_channels != 0 + + # [inp & position embedding] + init_std = math.sqrt(1 / self.C / 3) + self.norm0_cond = nn.Identity() + if self.t2i: + self.selecting_idx = None + self.num_classes = 0 + self.D = self.C + + cfg_uncond = torch.empty(self.text_maxlen, self.Ct5) + rng = torch.Generator(device='cpu') + rng.manual_seed(0) + torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng) + cfg_uncond /= self.Ct5 ** 0.5 + if rand_uncond: + self.register_buffer('cfg_uncond', cfg_uncond) + else: + self.cfg_uncond = nn.Parameter(cfg_uncond) + + self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps) + self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D) + self.text_proj_for_ca = nn.Sequential( + nn.Linear(self.Ct5, self.D), + nn.GELU(approximate='tanh'), + nn.Linear(self.D, self.D), + ) + else: # class-label cond + if selecting_idx is None: + num_classes = 1000 + print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======') + selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device()) + self.selecting_idx = selecting_idx + self.num_classes = selecting_idx.shape[-1] + self.D = self.C + self.class_emb = nn.Embedding(self.num_classes + 1, self.C) + nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std) + + self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C)) + nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std) + if self.rope2d_each_sa_layer: + rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw) + self.rope2d_freqs_grid = rope2d_freqs_grid + else: + raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented') + self.lvl_embed = nn.Embedding(15, self.C) + nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std) + + # [input layers] input norm && input embedding + norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps) + self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity() + self.word_embed = nn.Linear(self.d_vae, self.C) + + # [shared adaptive layernorm mapping network] + self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity() + + # fused norm + if fused_norm: + fused_norm_func = fused_ada_rms_norm if rms_norm else 
fused_ada_layer_norm + if fused_norm_func is not None: # pre-compile + B = 2 + x = torch.randn(B, 1, self.C).requires_grad_(True) + scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + # fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale, shift=shift).mean().backward() + del B, x, scale, shift + else: + fused_norm_func = None + + # [backbone and head] + self.use_flex_attn = use_flex_attn + self.attn_fn_compile_dict = {} + self.batch_size = batch_size + if self.use_flex_attn: + self.attn_fn_compile_dict = self.compile_flex_attn() + + self.drop_path_rate = drop_path_rate + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # dpr means drop path rate (linearly increasing) + self.unregistered_blocks = [] + for block_idx in range(depth): + block = (CrossAttnBlock if self.t2i else SelfAttnBlock)( + embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer, + num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn, + swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func, + checkpointing_sa_only=self.checkpointing == 'self-attn', + use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw, + ) + self.unregistered_blocks.append(block) + + # [head] + V = self.V + if head_aln: + self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func) + self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V)) + else: + self.head_nm = MultiInpIdentity() + self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V)) + + self.num_block_chunks = block_chunks or 1 + self.num_blocks_in_a_chunk = depth // block_chunks + print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}") + assert self.num_blocks_in_a_chunk * block_chunks == depth + if self.num_block_chunks == 1: + self.blocks = nn.ModuleList(self.unregistered_blocks) + else: + self.block_chunks = nn.ModuleList() + for i in range(self.num_block_chunks): + self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk)) + + ##### + self.time_embed = nn.Sequential( + nn.Linear(self.C//4, self.C), + nn.SiLU(), + nn.Linear(self.C, self.C), + ) + self.lq_linear = nn.Linear(self.C, self.C) + ##### + + print( + f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n' + f' [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n' + f' [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})', + end='\n\n', flush=True + ) + + + def compile_flex_attn(self): + 
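+        """
+        Pre-compile one FlexAttn kernel per (aspect-ratio template, scale schedule) pair.
+        The returned dict is keyed by the tuple of per-scale (t, h, w) patch counts and is
+        looked up at forward time via tuple(scale_schedule); sequence lengths are aligned
+        up to self.pad_to_multiplier before compilation.
+        """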
+        attn_fn_compile_dict = {}
+        for h_div_w in self.train_h_div_w_list:
+            h_div_w_template = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))]
+            full_scale_schedule = dynamic_resolution_h_w[h_div_w_template][self.pn]['scales']
+            if self.inference_mode:
+                apply_flex_attn_scales = list(range(1, 1+len(full_scale_schedule)))
+                mask_type = "infinity_infer_mask_with_kv_cache"
+                auto_padding = True
+            else:
+                mask_type = 'var'
+                auto_padding = False
+                apply_flex_attn_scales = [min(self.always_training_scales, len(full_scale_schedule))]
+            for scales_num in apply_flex_attn_scales:
+                print(f'====== apply flex attn h_div_w: {h_div_w} scales: {scales_num} ======')
+                scale_schedule = full_scale_schedule[:scales_num]
+                scale_schedule = [(min(t, self.video_frames//4+1), h, w) for (t, h, w) in scale_schedule]
+                patchs_nums_tuple = tuple(scale_schedule)
+                SEQ_L = sum(pt * ph * pw for pt, ph, pw in patchs_nums_tuple)
+                aligned_L = SEQ_L + (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L
+                attn_fn = FlexAttn(block_scales=patchs_nums_tuple,
+                                   mask_type=mask_type,
+                                   B=self.batch_size,
+                                   H=self.num_heads,
+                                   L=aligned_L,
+                                   auto_padding=auto_padding)
+                attn_fn_compile_dict[patchs_nums_tuple] = attn_fn
+
+                if self.video_frames > 1:  # video models (self.video_frames > 1) also need an image-only (t=1) attn_fn
+                    scale_schedule = [(1, h, w) for (t, h, w) in scale_schedule]
+                    patchs_nums_tuple = tuple(scale_schedule)
+                    SEQ_L = sum(pt * ph * pw for pt, ph, pw in patchs_nums_tuple)
+                    aligned_L = SEQ_L + (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L
+                    attn_fn = FlexAttn(block_scales=patchs_nums_tuple,
+                                       mask_type=mask_type,
+                                       B=self.batch_size,
+                                       H=self.num_heads,
+                                       L=aligned_L)
+                    attn_fn_compile_dict[patchs_nums_tuple] = attn_fn
+        return attn_fn_compile_dict
+
+    def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
+        """
+        :param h: hidden_state, shaped (B or batch_size, L or seq_len, C or hidden_dim)
+        :param cond_BD: shaped (B or batch_size, D or cond_dim)
+        :return: logits, shaped (B or batch_size, L or seq_len, V or vocabulary_size);
+                 the temperature is applied by callers via logits.mul(1/tau)
+        """
+        with torch.amp.autocast('cuda', enabled=False):
+            return self.head(self.head_nm(h.float(), cond_BD.float()))
+
+    def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0):
+        bs, seq_len, c = feature.shape
+        patch_t, patch_h, patch_w = scale_schedule[scale_ind]
+        t_mul_h_mul_w = patch_t * patch_h * patch_w
+        assert t_mul_h_mul_w + need_to_pad == seq_len
+        feature[:, :t_mul_h_mul_w] += self.lvl_embed(scale_ind*torch.ones((bs, t_mul_h_mul_w), dtype=torch.int).to(feature.device))
+        return feature
+
+    def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0):
+        ptr = 0
+        x_BLC_list = []
+        for scale_ind, patch_t_h_w in enumerate(scale_schedule):
+            scale_seq_len = np.array(patch_t_h_w).prod()
+            x_BLC_this_scale = x_BLC[:, ptr:ptr+scale_seq_len]  # shape: [bs, patch_t*patch_h*patch_w, c]
+            ptr += scale_seq_len
+            x_BLC_this_scale = self.add_lvl_embeding(x_BLC_this_scale, scale_ind, scale_schedule)
+            x_BLC_list.append(x_BLC_this_scale)
+        assert x_BLC.shape[1] == (ptr + need_to_pad), f'{x_BLC.shape[1]} != {ptr} + {need_to_pad}'
+        x_BLC_list.append(x_BLC[:, ptr:])
+        x_BLC = torch.cat(x_BLC_list, dim=1)
+        return x_BLC
+
+    def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
cfg_infer=False, + x_BLC_w_prefix_lq=None, + index=None, + **kwargs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV + """ + label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k) + :return: logits BLV, V is vocab_size + """ + if cfg_infer: + return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs) + x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32 + x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.float() + B = x_BLC_wo_prefix.shape[0] + + # [1. get input sequence x_BLC] + with torch.amp.autocast('cuda', enabled=False): + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + # drop cond + total = 0 + for le in lens: + if random.random() < self.cond_drop_rate: + kv_compact[total:total+le] = self.cfg_uncond[:le] + total += le + must_on_graph = self.cfg_uncond[0, 0] * 0 + kv_compact = self.text_norm(kv_compact).contiguous() + sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32 + kv_compact = self.text_proj_for_ca(kv_compact).contiguous() + kv_compact[0, 0] += must_on_graph + ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + + cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32 + + ###### + sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1) + x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1) + x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq)) + + patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # note important pn[0]==1? + patch_nums_per_level_acc = [np.sum(patch_nums_per_level[:j+1]) for j in range(len(patch_nums_per_level))] + + noise = torch.randn_like(x_BLC).to(x_BLC.device) + mask = torch.zeros_like(x_BLC, dtype=torch.bool).to(x_BLC.device) + index_list = index.cpu().tolist() + patch_nums_per_batch = [patch_nums_per_level_acc[j] for j in index_list] + for j in range(len(patch_nums_per_batch)): + p = patch_nums_per_batch[j] + mask[j, :p, :] = 1 + x_BLC = torch.where(mask, x_BLC, noise) + ##### + + # [1.1. 
pad the seqlen dim] + l_end = x_BLC.shape[1] + need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0 + + if self.customized_flash_attn: + Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end] + Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end] + attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen) + # todo: solve need_to_pad here + elif self.use_flex_attn: + if need_to_pad: + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0' + attn_bias_or_two_vector = None + else: + d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1) + dT = d.transpose(1, 2) # dT: 11L + attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end) + attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL + if need_to_pad: + attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf) + attn_bias[0, 0, l_end:, 0] = 0 + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device) + + if self.use_flex_attn: + attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)] + else: + attn_fn = None + + ##### my code + t_emb = dist.timestep_embedding(index, self.C//4, repeat_only=False) + t_emb = self.time_embed(t_emb) + + x_BLC_lq = self.lq_linear(x_BLC_lq) + x_BLC = x_BLC + x_BLC_lq + # x_BLC = x_BLC + t_emb.unsqueeze(1) + ##### my code + + # [2. block loop] + SelfAttnBlock.forward, CrossAttnBlock.forward + checkpointing_full_block = self.checkpointing == 'full-block' and self.training + if self.num_block_chunks == 1: + for i, b in enumerate(self.blocks): + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if checkpointing_full_block: + x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, None, None, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False) + else: + x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) + else: + for i, chunk in enumerate(self.block_chunks): # this path + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) + + # [3. 
unpad the seqlen dim, and then get logits]
+        return self.get_logits(x_BLC[:, :l_end], cond_BD)  # return logits BLV, V is vocab_size
+
+    @torch.no_grad()
+    def autoregressive_infer_cfg(
+        self,
+        vae=None,
+        scale_schedule=None,
+        label_B_or_BLT=None,
+        B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
+        g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
+        returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
+        cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
+        vae_type=0, softmax_merge_topk=-1, ret_img=False,
+        trunk_scale=1000,
+        gt_leak=0, gt_ls_Bl=None,
+        inference_mode=False,
+        save_img_path=None,
+        sampling_per_bits=1,
+        x_BLC_w_prefix_lq=None,
+    ):  # returns List[idx_Bl]
+        if g_seed is None: rng = None
+        else: self.rng.manual_seed(g_seed); rng = self.rng
+        assert len(cfg_list) >= len(scale_schedule)
+        assert len(tau_list) >= len(scale_schedule)
+
+        # scale_schedule drives the transformer; the VAE uses vae_scale_schedule. When spatial
+        # patchify is applied, convert scale_schedule to vae_scale_schedule by multiplying h and w by 2.
+        if self.apply_spatial_patchify:
+            vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
+        else:
+            vae_scale_schedule = scale_schedule
+
+        kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
+        if any(np.array(cfg_list) != 1):
+            bs = 2*B
+            if not negative_label_B_or_BLT:
+                kv_compact_un = kv_compact.clone()
+                total = 0
+                for le in lens:
+                    kv_compact_un[total:total+le] = self.cfg_uncond[:le]
+                    total += le
+                kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
+                cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0)
+            else:
+                kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
+                kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
+                cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0)
+                max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
+        else:
+            bs = B
+
+        kv_compact = self.text_norm(kv_compact)
+        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k))  # sos shape: [2, 4096]
+        kv_compact = self.text_proj_for_ca(kv_compact)  # kv_compact shape: [304, 4096]
+        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
+        last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)
+
+        with torch.amp.autocast('cuda', enabled=False):
+            cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()
+        accu_BChw, cur_L, ret = None, 0, []  # current length, list of reconstructed images
+        idx_Bl_list, idx_Bld_list = [], []
+
+        if inference_mode:
+            for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True)
+        else:
+            assert self.num_block_chunks > 1
+            for block_chunk_ in self.block_chunks:
+                for module in block_chunk_.module.module:
+                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True)
+
+        abs_cfg_insertion_layers = []
+        add_cfg_on_logits, add_cfg_on_probs = False, False
+        leng = len(self.unregistered_blocks)
+        for item in cfg_insertion_layer:
+            if item == 0:    # add cfg on logits
+                add_cfg_on_logits = True
+            elif item == 1:  # add cfg on probs
+                add_cfg_on_probs = True  # TODO: in the future we may want to add cfg on both logits and probs
+            elif item < 0:   # add cfg at the item-th layer's output
+                assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={leng}'
+                abs_cfg_insertion_layers.append(leng+item)
+            else:
+                raise ValueError(f'cfg_insertion_layer: {item} is not valid')
+
+        num_stages_minus_1 = len(scale_schedule)-1
+        summed_codes = 0
+
+        # TODO: per-scale guidance is collapsed here; only cfg_list[0] is used for the single forward pass
+        cfg = cfg_list[0]
+
+        # x_BLC_lq = self.car_control_convs(lq_images)
+        # x_BLC_lq = x_BLC_lq.view(B, self.C, -1).transpose(1, 2).contiguous()
+        # if cfg!=1:
+        #     x_BLC_lq = torch.cat([x_BLC_lq,x_BLC_lq],dim=0)
+
+        patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]  # note: pn[0] == 1 for images (single temporal slot)
+        patch_nums_per_level_acc = [np.sum(patch_nums_per_level[:j+1]) for j in range(len(patch_nums_per_level))]
+        x_BLC = torch.randn((bs, patch_nums_per_level_acc[-1], last_stage.shape[-1])).to(last_stage.device)
+        x_BLC[:, :1, :] = last_stage
+        l_end = x_BLC.shape[1]
+
+        x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.float()
+        x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq))
+        x_BLC_lq = self.lq_linear(x_BLC_lq)
+        x_BLC = x_BLC + x_BLC_lq
+
+        index = torch.zeros((bs,)).to(x_BLC.device)  # scale index 0 for the single-step pass (changed from torch.ones to torch.zeros)
+        t_emb = dist.timestep_embedding(index, self.C//4, repeat_only=False)
+        t_emb = self.time_embed(t_emb)
+
+        # x_BLC = x_BLC + t_emb.unsqueeze(1)
+
+        layer_idx = 0
+
+        for block_idx, b in enumerate(self.block_chunks):
+            # last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape: [5], ca_kv[2]: int
+            if self.add_lvl_embeding_only_first_block and block_idx == 0:
+                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad=0)
+            if not self.add_lvl_embeding_only_first_block:
+                x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad=0)
+            for m in b.module:
+                # TODO: pass the real scale index (scale_ind=si) instead of the constant 0
+                x_BLC = m(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=0)
+                if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers):
+                    # print(f'add cfg={cfg} on {layer_idx}-th layer output')
+                    x_BLC = cfg * x_BLC[:B] + (1-cfg) * x_BLC[B:]
+                    x_BLC = torch.cat((x_BLC, x_BLC), 0)
+                layer_idx += 1
+
+        # TODO: per-scale tau is likewise collapsed to tau_list[0]
+        if (cfg != 1) and add_cfg_on_logits:
+            logits_BlV = self.get_logits(x_BLC[:, :l_end], cond_BD).mul(1/tau_list[0])
+            logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:]
+        else:
+            logits_BlV = self.get_logits(x_BLC[:B, :l_end], cond_BD[:B]).mul(1/tau_list[0])
+
+        img = self.logits_to_img_discrete(logits_BlV_all=logits_BlV,
+                                          vae=vae,
+                                          scale_schedule=scale_schedule,
+                                          top_k=top_k,
+                                          top_p=top_p,
+                                          g_seed=g_seed)
+
+        if inference_mode:
+            for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False)
+        else:
+            assert self.num_block_chunks > 1
+            for block_chunk_ in self.block_chunks:
+                for module in block_chunk_.module.module:
+                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False)
+
+        img = (img + 1) / 2
+        img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
+
+        return None, None, img
+
+    @torch.no_grad()
+    def autoregressive_infer_cfg_multi_step(
+        self,
+        vae=None,
+        scale_schedule=None,
+        label_B_or_BLT=None,
+        B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
+        g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
+        returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
+        cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
+        vae_type=0, softmax_merge_topk=-1, ret_img=False,
+        trunk_scale=1000,
+        gt_leak=0, gt_ls_Bl=None,
+        inference_mode=False,
+        save_img_path=None,
+        sampling_per_bits=1,
+        x_BLC_w_prefix_lq=None,
+        index_list=None
+    ):  # returns List[idx_Bl]
+        if g_seed is None: rng = None
+        else: self.rng.manual_seed(g_seed); rng = self.rng
+        assert len(cfg_list) >= len(scale_schedule)
+        assert len(tau_list) >= len(scale_schedule)
+
+        assert index_list[0] == 0
+
+        # scale_schedule drives the transformer; the VAE uses vae_scale_schedule. When spatial
+        # patchify is applied, convert scale_schedule to vae_scale_schedule by multiplying h and w by 2.
+        if self.apply_spatial_patchify:
+            vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
+        else:
+            vae_scale_schedule = scale_schedule
+
+        kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
+        if any(np.array(cfg_list) != 1):
+            bs = 2*B
+            if not negative_label_B_or_BLT:
+                kv_compact_un = kv_compact.clone()
+                total = 0
+                for le in lens:
+                    kv_compact_un[total:total+le] = self.cfg_uncond[:le]
+                    total += le
+                kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
+                cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0)
+            else:
+                kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
+                kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
+                cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0)
+                max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
+        else:
+            bs = B
+
+        kv_compact = self.text_norm(kv_compact)
+        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k))  # sos shape: [2, 4096]
+        kv_compact = self.text_proj_for_ca(kv_compact)  # kv_compact shape: [304, 4096]
+        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
+        last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)
+
+        with torch.amp.autocast('cuda', enabled=False):
+            cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()
+        accu_BChw, cur_L, ret = None, 0, []  # current length, list of reconstructed images
+        idx_Bl_list, idx_Bld_list = [], []
+
+        if inference_mode:
+            for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True)
+        else:
+            assert self.num_block_chunks > 1
+            for block_chunk_ in self.block_chunks:
+                for module in block_chunk_.module.module:
+                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True)
+
+        abs_cfg_insertion_layers = []
+        add_cfg_on_logits, add_cfg_on_probs = False, False
+        leng = len(self.unregistered_blocks)
+        for item in cfg_insertion_layer:
+            if item == 0:    # add cfg on logits
+                add_cfg_on_logits = True
+            elif item == 1:  # add cfg on probs
+                add_cfg_on_probs = True  # TODO: in the future we may want to add cfg on both logits and probs
+            elif item < 0:   # add cfg at the item-th layer's output
+                assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={leng}'
+                abs_cfg_insertion_layers.append(leng+item)
+            else:
+                raise ValueError(f'cfg_insertion_layer: {item} is not valid')
+
+        num_stages_minus_1 = len(scale_schedule)-1
+        summed_codes = 0
+
+        # TODO: per-scale guidance is collapsed here; only cfg_list[0] is used
+        cfg = cfg_list[0]
+
+        # x_BLC_lq = self.car_control_convs(lq_images)
+
+        # x_BLC_lq = x_BLC_lq.view(B, self.C, -1).transpose(1, 2).contiguous()
+        # if cfg!=1:
+        #     x_BLC_lq = torch.cat([x_BLC_lq,x_BLC_lq],dim=0)
+
+        x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.float()
+        x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq))
+        x_BLC_lq = self.lq_linear(x_BLC_lq)
+
+        patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]  # note: pn[0] == 1 for images (single temporal slot)
+        patch_nums_per_level_acc = [np.sum(patch_nums_per_level[:j+1]) for j in range(len(patch_nums_per_level))]
+
+        x_BLC = torch.zeros((bs, patch_nums_per_level_acc[-1], last_stage.shape[-1])).to(last_stage.device)
+        x_BLC[:, :1, :] = last_stage
+        l_end = x_BLC.shape[1]
+
+        for index in index_list:
+
+            # keep tokens up through scale `index`; re-noise everything after them
+            noise = torch.randn_like(x_BLC).to(x_BLC.device)
+            x_BLC[:, patch_nums_per_level_acc[index]:, :] = noise[:, patch_nums_per_level_acc[index]:, :]
+
+            # add the low-quality conditioning features
+            x_BLC = x_BLC + x_BLC_lq
+
+            # add the timestep (scale-index) embedding
+            index_tensor = torch.full((bs,), index).to(x_BLC.device)
+            t_emb = dist.timestep_embedding(index_tensor, self.C//4, repeat_only=False)
+            t_emb = self.time_embed(t_emb)
+
+            # x_BLC = x_BLC + t_emb
+
+            layer_idx = 0
+            for block_idx, b in enumerate(self.block_chunks):
+                # last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape: [5], ca_kv[2]: int
+                if self.add_lvl_embeding_only_first_block and block_idx == 0:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad=0)
+                if not self.add_lvl_embeding_only_first_block:
+                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad=0)
+                for m in b.module:
+                    # TODO: pass the real scale index (scale_ind=si) instead of the constant 0
+                    x_BLC = m(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=0)
+                    if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers):
+                        # print(f'add cfg={cfg} on {layer_idx}-th layer output')
+                        x_BLC = cfg * x_BLC[:B] + (1-cfg) * x_BLC[B:]
+                        x_BLC = torch.cat((x_BLC, x_BLC), 0)
+                    layer_idx += 1
+
+            # TODO: per-scale tau is likewise collapsed to tau_list[0]
+            if (cfg != 1) and add_cfg_on_logits:
+                logits_BlV = self.get_logits(x_BLC[:, :l_end], cond_BD).mul(1/tau_list[0])
+                logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:]
+            else:
+                logits_BlV = self.get_logits(x_BLC[:B, :l_end], cond_BD[:B]).mul(1/tau_list[0])
+
+            if index == index_list[-1]:
+                logits_final = logits_BlV
+            else:
+                if self.use_bit_label:
+                    tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
+                    logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
+                    idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
+                    idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)
+                else:
+                    idx_Bl = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
+
+                cum_var_input = 0
+                x_BLC_wo_prefix = []
+                idx_Bld_list = list(torch.split(idx_Bld, patch_nums_per_level, dim=1))
+                for si, bit_indices in enumerate(idx_Bld_list):
+
+                    _, _, d_vae = bit_indices.shape
+                    bit_indices = bit_indices.reshape((B, vae_scale_schedule[si][0], vae_scale_schedule[si][1], vae_scale_schedule[si][2], d_vae))
+
+                    quantized = vae.quantizer.lfq.indices_to_codes(bit_indices, label_type='bit_label')
+                    quantized_up = F.interpolate(quantized, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
+                    cum_var_input += quantized_up
+
+                    if si < len(vae_scale_schedule)-1:
+                        this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_up)
+                        if self.apply_spatial_patchify:
+                            this_scale_input = torch.nn.functional.pixel_unshuffle(this_scale_input.squeeze(-3), 2)
+                        x_BLC_wo_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0,2,1))
+
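+                # Assemble the next step's transformer input from the codes just sampled: the
+                # per-scale cumulative code sums are re-embedded below (norm0_ve + word_embed), so
+                # the next pass conditions on its own predictions, mirroring teacher-forced training.
+                # Note: this re-assembly appears to assume cfg == 1 -- last_stage is batched at bs
+                # while the re-embedded codes are batched at B, so the cat below would mismatch
+                # when bs == 2B.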
+ x_BLC_wo_prefix = torch.cat(x_BLC_wo_prefix, 1) + x_BLC = torch.cat((last_stage, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1) + assert x_BLC.shape[1] == l_end + + + img = self.logits_to_img_discrete(logits_BlV_all=logits_final, + vae=vae, + scale_schedule=scale_schedule, + top_k=top_k, + top_p=top_p, + g_seed=g_seed) + + if inference_mode: + for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) + else: + assert self.num_block_chunks > 1 + for block_chunk_ in self.block_chunks: + for module in block_chunk_.module.module: + (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) + + img = (img + 1) / 2 + img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) + + return None,None,img + + def logits_to_img(self,logits_BlV_all,vae,scale_schedule,top_k=900,top_p=0.97,g_seed=1): + # logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si]) + patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # note important pn[0]==1? + logits_BlV_list = list(torch.split(logits_BlV_all,patch_nums_per_level,dim=1)) + + B = logits_BlV_all.shape[0] + + if g_seed is None: rng = None + else: self.rng.manual_seed(g_seed); rng = self.rng + + if self.apply_spatial_patchify: + vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] + else: + vae_scale_schedule = scale_schedule + + summed_codes = 0 + num_stages_minus_1 = len(scale_schedule)-1 + + for si,logits_BlV in enumerate(logits_BlV_list): + pn= scale_schedule[si] + if self.use_bit_label: + tmp_bs, tmp_seq_len = logits_BlV.shape[:2] + logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) + + #idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + + ##### + idx_Bld = GumbelArgmax(logits_BlV, 0.5) + tmp_tensor = torch.zeros_like(idx_Bld).to(idx_Bld.device) + tmp_tensor[:,:,1:]=1 + idx_Bld = idx_Bld * tmp_tensor + idx_Bld = idx_Bld.sum(dim=-1) + ##### + + idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) + else: + idx_Bl = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + + ##### vae_type!=0 + ###si>=gt_leak + assert pn[0] == 1 + idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d] + if self.apply_spatial_patchify: # unpatchify operation + idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w] + idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w] + idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d] + idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d] + + + codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + if si != num_stages_minus_1: + summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) + else: + summed_codes += codes + + # if inference_mode: + # for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) + # else: + # assert self.num_block_chunks > 1 + # for block_chunk_ in self.block_chunks: + # for module in block_chunk_.module.module: + # (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) + + #vae_type != 0: + img = vae.decode(summed_codes.squeeze(-3)) + # img = (img + 1) / 2 + # img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) + 
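+        # note: unlike autoregressive_infer_cfg, this helper returns the decoded image in [-1, 1];
+        # the (img + 1) / 2 and uint8/BGR conversion (commented out above) are left to the caller.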
return img + + def logits_to_img_discrete(self,logits_BlV_all,vae,scale_schedule,top_k=900,top_p=0.97,g_seed=1): + # logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si]) + patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # note important pn[0]==1? + logits_BlV_list = list(torch.split(logits_BlV_all,patch_nums_per_level,dim=1)) + + B = logits_BlV_all.shape[0] + + if g_seed is None: rng = None + else: self.rng.manual_seed(g_seed); rng = self.rng + + if self.apply_spatial_patchify: + vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] + else: + vae_scale_schedule = scale_schedule + + summed_codes = 0 + num_stages_minus_1 = len(scale_schedule)-1 + + for si,logits_BlV in enumerate(logits_BlV_list): + pn= scale_schedule[si] + if self.use_bit_label: + tmp_bs, tmp_seq_len = logits_BlV.shape[:2] + logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) + + idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + + # ##### + # idx_Bld = GumbelArgmax(logits_BlV, 0.5) + # tmp_tensor = torch.zeros_like(idx_Bld).to(idx_Bld.device) + # tmp_tensor[:,:,1:]=1 + # idx_Bld = idx_Bld * tmp_tensor + # idx_Bld = idx_Bld.sum(dim=-1) + # ##### + + idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) + else: + idx_Bl = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + + ##### vae_type!=0 + ###si>=gt_leak + assert pn[0] == 1 + idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d] + if self.apply_spatial_patchify: # unpatchify operation + idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w] + idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w] + idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d] + idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d] + + + codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + if si != num_stages_minus_1: + summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) + else: + summed_codes += codes + + # if inference_mode: + # for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) + # else: + # assert self.num_block_chunks > 1 + # for block_chunk_ in self.block_chunks: + # for module in block_chunk_.module.module: + # (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) + + #vae_type != 0: + img = vae.decode(summed_codes.squeeze(-3)) + # img = (img + 1) / 2 + # img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) + return img + + + @for_visualize + def vis_key_params(self, ep): + return + + def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False): + for k in state_dict: + if 'cfg_uncond' in k: + old, new = state_dict[k], self.cfg_uncond.data + min_tlen = min(old.shape[0], new.shape[0]) + if min_tlen == old.shape[0]: + state_dict[k] = torch.cat((old.to(device=new.device, dtype=new.dtype), new[min_tlen:])) + else: + state_dict[k] = old[:min_tlen] + + for buf_name in ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen'): + state_dict.pop(buf_name, None) + if hasattr(self, buf_name): + state_dict[buf_name] = getattr(self, buf_name) + + return super().load_state_dict(state_dict=state_dict, strict=strict, 
assign=assign) + + def special_init( + self, + aln_init: float, + aln_gamma_init: float, + scale_head: float, + scale_proj: int, + ): + # init head's norm + if isinstance(self.head_nm, AdaLNBeforeHead): + self.head_nm.ada_lin[-1].weight.data.mul_(aln_init) # there's no gamma for head + if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None: + self.head_nm.ada_lin[-1].bias.data.zero_() + + # init head's proj + if scale_head >= 0: + if isinstance(self.head, nn.Linear): + self.head.weight.data.mul_(scale_head) + self.head.bias.data.zero_() + elif isinstance(self.head, nn.Sequential): + self.head[-1].weight.data.mul_(scale_head) + self.head[-1].bias.data.zero_() + + depth = len(self.unregistered_blocks) + for block_idx, sab in enumerate(self.unregistered_blocks): + sab: Union[SelfAttnBlock, CrossAttnBlock] + # init proj + scale = 1 / math.sqrt(2*depth if scale_proj == 1 else 2*(1 + block_idx)) + if scale_proj == 1: + if self.t2i: + sab.sa.proj.weight.data.mul_(scale) + sab.ca.proj.weight.data.mul_(scale) + else: + sab.attn.proj.weight.data.mul_(scale) + sab.ffn.fc2.weight.data.mul_(scale) + # if sab.using_swiglu: + # nn.init.ones_(sab.ffn.fcg.bias) + # nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5) + + # init ada_lin + if hasattr(sab, 'ada_lin'): + lin = sab.ada_lin[-1] + lin.weight.data[:2*self.C].mul_(aln_gamma_init) # init gamma + lin.weight.data[2*self.C:].mul_(aln_init) # init scale and shift + if hasattr(lin, 'bias') and lin.bias is not None: + lin.bias.data.zero_() + elif hasattr(sab, 'ada_gss'): + sab.ada_gss.data[:, :, :2, :].mul_(aln_gamma_init) # init gamma + sab.ada_gss.data[:, :, 2:, :].mul_(aln_init) # init scale and shift + + def extra_repr(self): + return f'drop_path_rate={self.drop_path_rate}' + + def get_layer_id_and_scale_exp(self, para_name: str): + raise NotImplementedError + +class FInfinity(nn.Module): + def __init__( + self, vae_local, + text_channels=0, text_maxlen=0, # text-cond generation + selecting_idx=None, # class-cond generation + embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4., # model's architecture + drop_rate=0., drop_path_rate=0., # drop out and drop path + norm_eps=1e-6, rms_norm=False, # norm layer + shared_aln=False, head_aln=True, # adaptive norm + cond_drop_rate=0.1, # for classifier-free guidance + rand_uncond=False, + cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False, + raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), + head_depth=1, + top_p=0.0, top_k=0.0, + customized_flash_attn=False, fused_mlp=False, fused_norm=False, + block_chunks=1, + checkpointing=None, + pad_to_multiplier=0, + use_flex_attn=False, + batch_size=2, + add_lvl_embeding_only_first_block=1, + use_bit_label=1, + rope2d_each_sa_layer=0, + rope2d_normalized_by_hw=0, + pn=None, + train_h_div_w_list=None, + video_frames=1, + always_training_scales=20, + apply_spatial_patchify = 0, + inference_mode=False, + ): + # set hyperparameters + self.C = embed_dim + self.inference_mode = inference_mode + self.apply_spatial_patchify = apply_spatial_patchify + if self.apply_spatial_patchify: + self.d_vae = vae_local.embed_dim * 4 + else: + self.d_vae = vae_local.embed_dim + self.use_bit_label = use_bit_label + self.codebook_dim = self.d_vae + self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size + self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None + self.Ct5 = text_channels + self.depth = depth + self.num_heads = num_heads + self.batch_size = batch_size + self.mlp_ratio = 
mlp_ratio + self.cond_drop_rate = cond_drop_rate + self.norm_eps = norm_eps + self.prog_si = -1 + self.pn = pn + self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates + self.video_frames = video_frames + self.always_training_scales = always_training_scales + + assert add_lvl_embeding_only_first_block in [0,1] + self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block + assert rope2d_each_sa_layer in [0,1] + self.rope2d_each_sa_layer = rope2d_each_sa_layer + self.rope2d_normalized_by_hw = rope2d_normalized_by_hw + print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, \ + self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}') + head_up_method = '' + word_patch_size = 1 if head_up_method in {'', 'no'} else 2 + if word_patch_size > 1: + assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}' + + self.checkpointing = checkpointing + self.pad_to_multiplier = max(1, pad_to_multiplier) + + customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames) + self.customized_flash_attn = customized_flash_attn and customized_kernel_installed + if customized_flash_attn and not customized_kernel_installed: + import inspect, warnings + file_path = inspect.getsourcefile(flash_attn_func) + line_number = inspect.getsourcelines(flash_attn_func)[1] + info = ( + f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. 
<<<<<<\n' + f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n' + f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n' + ) + warnings.warn(info, ImportWarning) + print(info, flush=True) + + self.raw_scale_schedule = raw_scale_schedule # 'raw' means before any patchifying + self.first_l = 1 + # solve top-p top-k sampling hyperparameters + self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k)) + if self.top_p < 1e-5: self.top_p = 0 + if self.top_k >= self.V or self.top_k <= 0: self.top_k = 0 + + t = torch.zeros(dist.get_world_size(), device=dist.get_device()) + t[dist.get_rank()] = float(flash_fused_op_installed) + dist.barrier() + dist.allreduce(t) + assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}' + + super().__init__() + self.rng = torch.Generator(device=dist.get_device()) + self.maybe_record_function = nullcontext + self.text_maxlen = text_maxlen + self.t2i = text_channels != 0 + + # [inp & position embedding] + init_std = math.sqrt(1 / self.C / 3) + self.norm0_cond = nn.Identity() + + self.time_embed = nn.Sequential( + nn.Linear(self.C//4, self.C), + nn.SiLU(), + nn.Linear(self.C, self.C), + ) + + if self.t2i: + self.selecting_idx = None + self.num_classes = 0 + self.D = self.C + + cfg_uncond = torch.empty(self.text_maxlen, self.Ct5) + rng = torch.Generator(device='cpu') + rng.manual_seed(0) + torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng) + cfg_uncond /= self.Ct5 ** 0.5 + if rand_uncond: + self.register_buffer('cfg_uncond', cfg_uncond) + else: + self.cfg_uncond = nn.Parameter(cfg_uncond) + + self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps) + self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D) + self.text_proj_for_ca = nn.Sequential( + nn.Linear(self.Ct5, self.D), + nn.GELU(approximate='tanh'), + nn.Linear(self.D, self.D), + ) + else: # class-label cond + if selecting_idx is None: + num_classes = 1000 + print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======') + selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device()) + self.selecting_idx = selecting_idx + self.num_classes = selecting_idx.shape[-1] + self.D = self.C + self.class_emb = nn.Embedding(self.num_classes + 1, self.C) + nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std) + + self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C)) + nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std) + if self.rope2d_each_sa_layer: + rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw) + self.rope2d_freqs_grid = rope2d_freqs_grid + else: + raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented') + self.lvl_embed = nn.Embedding(15, self.C) + nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std) + + # [input layers] input norm && input embedding + norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps) + self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity() + self.word_embed = nn.Linear(self.d_vae, self.C) + + #my code + self.norm0_ve_lq = norm_layer(self.d_vae) if nm0 else nn.Identity() + self.word_embed_lq = nn.Linear(self.d_vae, self.C) + + # [shared adaptive 
layernorm mapping network] + self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity() + + # fused norm + if fused_norm: + fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm + if fused_norm_func is not None: # pre-compile + B = 2 + x = torch.randn(B, 1, self.C).requires_grad_(True) + scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + # fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale, shift=shift).mean().backward() + del B, x, scale, shift + else: + fused_norm_func = None + + # [backbone and head] + self.use_flex_attn = use_flex_attn + self.attn_fn_compile_dict = {} + self.batch_size = batch_size + if self.use_flex_attn: + self.attn_fn_compile_dict = self.compile_flex_attn() + + self.drop_path_rate = drop_path_rate + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # dpr means drop path rate (linearly increasing) + self.unregistered_blocks = [] + for block_idx in range(depth): + block = (CrossAttnBlock if self.t2i else SelfAttnBlock)( + embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer, + num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn, + swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func, + checkpointing_sa_only=self.checkpointing == 'self-attn', + use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw, + ) + self.unregistered_blocks.append(block) + + # [head] + V = self.V + if head_aln: + self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func) + self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V)) + else: + self.head_nm = MultiInpIdentity() + self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V)) + + self.num_block_chunks = block_chunks or 1 + self.num_blocks_in_a_chunk = depth // block_chunks + print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}") + assert self.num_blocks_in_a_chunk * block_chunks == depth + if self.num_block_chunks == 1: + self.blocks = nn.ModuleList(self.unregistered_blocks) + else: + self.block_chunks = nn.ModuleList() + for i in range(self.num_block_chunks): + self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk)) + print( + f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n' + f' [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n' + f' [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})', + end='\n\n', 
flush=True + ) + self.car_control_convs = ControlConditionEmbedding(conditioning_embedding_channels=self.C) + # self.car_control_convs = ControlConditionEmbedding_patch_size_32(conditioning_embedding_channels=self.C) + + + def compile_flex_attn(self): + attn_fn_compile_dict = {} + for h_div_w in self.train_h_div_w_list: + h_div_w_template = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))] + full_scale_schedule = dynamic_resolution_h_w[h_div_w_template][self.pn]['scales'] + if self.inference_mode: + apply_flex_attn_scales = list(range(1, 1+len(full_scale_schedule))) + mask_type = "infinity_infer_mask_with_kv_cache" + auto_padding = True + else: + mask_type = 'var' + auto_padding = False + apply_flex_attn_scales = [min(self.always_training_scales, len(full_scale_schedule))] + for scales_num in apply_flex_attn_scales: + print(f'====== apply flex attn hdivw: {h_div_w} scales: {scales_num} ======') + scale_schedule = full_scale_schedule[:scales_num] + scale_schedule = [ (min(t, self.video_frames//4+1), h, w) for (t,h, w) in scale_schedule] + patchs_nums_tuple = tuple(scale_schedule) + SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple) + aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L + attn_fn = FlexAttn(block_scales = patchs_nums_tuple, + mask_type = mask_type, + B = self.batch_size, + H = self.num_heads, + L = aligned_L, + auto_padding=auto_padding) + attn_fn_compile_dict[patchs_nums_tuple] = attn_fn + + if self.video_frames > 1: # append image attn_fn when self.video_frames > 1 (namely videos) + scale_schedule = [ (1, h, w) for (t,h, w) in scale_schedule] + patchs_nums_tuple = tuple(scale_schedule) + SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple) + aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L + attn_fn = FlexAttn(block_scales = patchs_nums_tuple, + mask_type = mask_type, + B = self.batch_size, + H = self.num_heads, + L = aligned_L) + attn_fn_compile_dict[patchs_nums_tuple] = attn_fn + return attn_fn_compile_dict + + def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]): + """ + :param h: hidden_state, shaped (B or batch_size, L or seq_len, C or hidden_dim) + :param cond_BD: shaped (B or batch_size, D or cond_dim) + :param tau: temperature + :return: logits, shaped (B or batch_size, V or vocabulary_size) + """ + with torch.amp.autocast('cuda', enabled=False): + return self.head(self.head_nm(h.float(), cond_BD.float())) + + def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0): + bs, seq_len, c = feature.shape + patch_t, patch_h, patch_w = scale_schedule[scale_ind] + t_mul_h_mul_w = patch_t * patch_h * patch_w + assert t_mul_h_mul_w + need_to_pad == seq_len + feature[:, :t_mul_h_mul_w] += self.lvl_embed(scale_ind*torch.ones((bs, t_mul_h_mul_w),dtype=torch.int).to(feature.device)) + return feature + + def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0): + ptr = 0 + x_BLC_list = [] + for scale_ind, patch_t_h_w in enumerate(scale_schedule): + scale_seq_len = np.array(patch_t_h_w).prod() + x_BLC_this_scale = x_BLC[:,ptr:ptr+scale_seq_len] # shape: [bs, patch_h*patch_w, c] + ptr += scale_seq_len + x_BLC_this_scale = self.add_lvl_embeding(x_BLC_this_scale, scale_ind, scale_schedule) + x_BLC_list.append(x_BLC_this_scale) + + # assert x_BLC.shape[1] == (ptr + need_to_pad), f'{x_BLC.shape[1]} != {ptr} + {need_to_pad}' + + 
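+        # note: the shape assert used in the FAInfinity version is commented out above, presumably
+        # so this variant can accept sequences whose length differs from the scale schedule; the
+        # remaining tail (if any) is appended unchanged below.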
x_BLC_list.append(x_BLC[:,ptr:]) + x_BLC = torch.cat(x_BLC_list, dim=1) + return x_BLC + + # def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]], + # cfg_infer=False,x_BLC_w_prefix_lq=None,index=None, + # **kwargs, + # ) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV + # """ + # label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k) + # :return: logits BLV, V is vocab_size + # """ + # if cfg_infer: + # return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs) + + # x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32 + # x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.float() + # B = x_BLC_wo_prefix.shape[0] + + # # [1. get input sequence x_BLC] + # with torch.amp.autocast('cuda', enabled=False): + # kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + # # drop cond + # total = 0 + # for le in lens: + # if random.random() < self.cond_drop_rate: + # kv_compact[total:total+le] = self.cfg_uncond[:le] + # total += le + # must_on_graph = self.cfg_uncond[0, 0] * 0 + # kv_compact = self.text_norm(kv_compact).contiguous() + # sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32 + # kv_compact = self.text_proj_for_ca(kv_compact).contiguous() + # kv_compact[0, 0] += must_on_graph + # ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + + # cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32 + + # sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1) + # x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1) + # x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq)) + + # patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # note important pn[0]==1? + # patch_nums_per_level_acc = [np.sum(patch_nums_per_level[:j+1]) for j in range(len(patch_nums_per_level))] + + # noise = torch.randn_like(x_BLC).to(x_BLC.device) + # mask = torch.zeros_like(x_BLC, dtype=torch.bool).to(x_BLC.device) + # index_list = index.cpu().tolist() + # patch_nums_per_batch = [patch_nums_per_level_acc[j] for j in index_list] + # for j in range(len(patch_nums_per_batch)): + # p = patch_nums_per_batch[j] + # mask[j, :p, :] = 1 + # x_BLC = torch.where(mask, x_BLC, noise) + + # # [1.1. 
pad the seqlen dim] + # l_end = x_BLC.shape[1] + # need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0 + + # if self.customized_flash_attn: + # Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end] + # Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end] + # attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen) + # # todo: solve need_to_pad here + # elif self.use_flex_attn: + # if need_to_pad: + # x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + # #note x_BLC_lq padding + # x_BLC_lq = F.pad(x_BLC_lq, (0, 0, 0, need_to_pad)) + + # assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0' + # attn_bias_or_two_vector = None + # else: + # d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1) + # dT = d.transpose(1, 2) # dT: 11L + # attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end) + # attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL + # if need_to_pad: + # attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf) + # attn_bias[0, 0, l_end:, 0] = 0 + # #note x_BLC_lq padding + # x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + # x_BLC_lq = F.pad(x_BLC_lq,(0, 0, 0, need_to_pad)) + + # attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device) + + # if self.use_flex_attn: + # attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)] + # else: + # attn_fn = None + + # # [2. block loop] + # SelfAttnBlock.forward, CrossAttnBlock.forward + # checkpointing_full_block = self.checkpointing == 'full-block' and self.training + + # t_emb = dist.timestep_embedding(index, self.C//4, repeat_only=False) + # t_emb = self.time_embed(t_emb) + + # if self.num_block_chunks == 1: + # for i, b in enumerate(self.blocks): + # if self.add_lvl_embeding_only_first_block and i == 0: + # x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + # x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + # if not self.add_lvl_embeding_only_first_block: + # x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + # x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + + # x_BLC = torch.cat([x_BLC,x_BLC_lq],dim = 1) + # # add time embedding + # x_BLC = x_BLC + t_emb + # if checkpointing_full_block: + # x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, None, None, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False) + # else: + # x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) + # else: + # for i, chunk in enumerate(self.block_chunks): # this path + # if self.add_lvl_embeding_only_first_block and i == 0: + # x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + # x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + # if not self.add_lvl_embeding_only_first_block: + # x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + # x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + + # x_BLC = torch.cat([x_BLC,x_BLC_lq],dim = 1) + # x_BLC = x_BLC + t_emb.unsqueeze(1) + # x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, 
scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) + # # [3. unpad the seqlen dim, and then get logits] + # return self.get_logits(x_BLC[:, :l_end], cond_BD) # return logits BLV, V is vocab_size + + # def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]], + # cfg_infer=False,raw_features_seq=None,index=None, #here raw_features_seq is quantized_raw_features_seq + # **kwargs, + # ) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV + # """ + # label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k) + # :return: logits BLV, V is vocab_size + # """ + # if cfg_infer: + # return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs) + + # x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32 + # raw_features_seq = raw_features_seq.float() + # B = x_BLC_wo_prefix.shape[0] + + # # [1. get input sequence x_BLC] + # with torch.amp.autocast('cuda', enabled=False): + # kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + # # drop cond + # total = 0 + # for le in lens: + # if random.random() < self.cond_drop_rate: + # kv_compact[total:total+le] = self.cfg_uncond[:le] + # total += le + # must_on_graph = self.cfg_uncond[0, 0] * 0 + # kv_compact = self.text_norm(kv_compact).contiguous() + # sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32 + # kv_compact = self.text_proj_for_ca(kv_compact).contiguous() + # kv_compact[0, 0] += must_on_graph + # ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + + # cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32 + + # sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1) + # x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1) + # x_BLC_lq = self.word_embed_lq(self.norm0_ve_lq(raw_features_seq)) + + # patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # note important pn[0]==1? + # patch_nums_per_level_acc = [np.sum(patch_nums_per_level[:j+1]) for j in range(len(patch_nums_per_level))] + + # noise = torch.randn_like(x_BLC).to(x_BLC.device) + # mask = torch.zeros_like(x_BLC, dtype=torch.bool).to(x_BLC.device) + # index_list = index.cpu().tolist() + # patch_nums_per_batch = [patch_nums_per_level_acc[j] for j in index_list] + # for j in range(len(patch_nums_per_batch)): + # p = patch_nums_per_batch[j] + # mask[j, :p, :] = 1 + # x_BLC = torch.where(mask, x_BLC, noise) + + # # [1.1. 
pad the seqlen dim] + # l_end = x_BLC.shape[1] + # need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0 + + # if self.customized_flash_attn: + # Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end] + # Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end] + # attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen) + # # todo: solve need_to_pad here + # elif self.use_flex_attn: + # if need_to_pad: + # x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + # #note x_BLC_lq padding + # # x_BLC_lq = F.pad(x_BLC_lq, (0, 0, 0, need_to_pad)) + + # assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0' + # attn_bias_or_two_vector = None + # else: + # d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1) + # dT = d.transpose(1, 2) # dT: 11L + # attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end) + # attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL + # if need_to_pad: + # attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf) + # attn_bias[0, 0, l_end:, 0] = 0 + # #note x_BLC_lq padding + # x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + # # x_BLC_lq = F.pad(x_BLC_lq,(0, 0, 0, need_to_pad)) + + # attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device) + + # if self.use_flex_attn: + # attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)] + # else: + # attn_fn = None + + # # [2. block loop] + # SelfAttnBlock.forward, CrossAttnBlock.forward + # checkpointing_full_block = self.checkpointing == 'full-block' and self.training + + # t_emb = dist.timestep_embedding(index, self.C//4, repeat_only=False) + # t_emb = self.time_embed(t_emb) + + # if self.num_block_chunks == 1: + # for i, b in enumerate(self.blocks): + # if self.add_lvl_embeding_only_first_block and i == 0: + # x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + # # x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + # if not self.add_lvl_embeding_only_first_block: + # x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + # # x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + + # x_BLC = torch.cat([x_BLC,x_BLC_lq],dim = 1) + # # add time embedding + # x_BLC = x_BLC + t_emb + # if checkpointing_full_block: + # x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, None, None, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False) + # else: + # x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) + # else: + # for i, chunk in enumerate(self.block_chunks): # this path + # if self.add_lvl_embeding_only_first_block and i == 0: + # x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + # # x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + # if not self.add_lvl_embeding_only_first_block: + # x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + # # x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + + # x_BLC = torch.cat([x_BLC,x_BLC_lq],dim = 1) + # x_BLC = x_BLC + t_emb.unsqueeze(1) + # x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, 
scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) + # # [3. unpad the seqlen dim, and then get logits] + # return self.get_logits(x_BLC[:, :l_end], cond_BD) # return logits BLV, V is vocab_size + + def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]], + cfg_infer=False,lq_images=None,index=None, #here raw_features_seq is quantized_raw_features_seq + **kwargs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV + """ + label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k) + :return: logits BLV, V is vocab_size + """ + if cfg_infer: + return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs) + + x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32 + lq_images = lq_images.float() + B = x_BLC_wo_prefix.shape[0] + + # [1. get input sequence x_BLC] + with torch.amp.autocast('cuda', enabled=False): + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + # drop cond + total = 0 + for le in lens: + if random.random() < self.cond_drop_rate: + kv_compact[total:total+le] = self.cfg_uncond[:le] + total += le + must_on_graph = self.cfg_uncond[0, 0] * 0 + kv_compact = self.text_norm(kv_compact).contiguous() + sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32 + kv_compact = self.text_proj_for_ca(kv_compact).contiguous() + kv_compact[0, 0] += must_on_graph + ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + + cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32 + + sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1) + x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1) + x_BLC_lq = self.car_control_convs(lq_images) + x_BLC_lq = x_BLC_lq.view(B, self.C, -1).transpose(1, 2).contiguous() + + patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # note important pn[0]==1? + patch_nums_per_level_acc = [np.sum(patch_nums_per_level[:j+1]) for j in range(len(patch_nums_per_level))] + + noise = torch.randn_like(x_BLC).to(x_BLC.device) + mask = torch.zeros_like(x_BLC, dtype=torch.bool).to(x_BLC.device) + index_list = index.cpu().tolist() + patch_nums_per_batch = [patch_nums_per_level_acc[j] for j in index_list] + for j in range(len(patch_nums_per_batch)): + p = patch_nums_per_batch[j] + mask[j, :p, :] = 1 + x_BLC = torch.where(mask, x_BLC, noise) + + # [1.1. 
pad the seqlen dim] + l_end = x_BLC.shape[1] + need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0 + + if self.customized_flash_attn: + Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end] + Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end] + attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen) + # todo: solve need_to_pad here + elif self.use_flex_attn: + if need_to_pad: + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + #note x_BLC_lq padding + # x_BLC_lq = F.pad(x_BLC_lq, (0, 0, 0, need_to_pad)) + + assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0' + attn_bias_or_two_vector = None + else: + d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1) + dT = d.transpose(1, 2) # dT: 11L + attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end) + attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL + if need_to_pad: + attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf) + attn_bias[0, 0, l_end:, 0] = 0 + #note x_BLC_lq padding + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + # x_BLC_lq = F.pad(x_BLC_lq,(0, 0, 0, need_to_pad)) + + attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device) + + if self.use_flex_attn: + attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)] + else: + attn_fn = None + + # [2. block loop] + SelfAttnBlock.forward, CrossAttnBlock.forward + checkpointing_full_block = self.checkpointing == 'full-block' and self.training + + t_emb = dist.timestep_embedding(index, self.C//4, repeat_only=False) + t_emb = self.time_embed(t_emb) + + x_BLC = torch.cat([x_BLC,x_BLC_lq],dim = 1) + x_BLC = x_BLC + t_emb.unsqueeze(1) + + if self.num_block_chunks == 1: + for i, b in enumerate(self.blocks): + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + # x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + # x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + if checkpointing_full_block: + x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, None, None, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False) + else: + x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) + else: + for i, chunk in enumerate(self.block_chunks): # this path + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + # x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + # x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) + # [3. 
unpad the seqlen dim, and then get logits] + return self.get_logits(x_BLC[:, :l_end], cond_BD) # return logits BLV, V is vocab_size + + def logits_to_img(self,logits_BlV_all,vae,scale_schedule,top_k=900,top_p=0.97,g_seed=1): + # logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si]) + patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # note important pn[0]==1? + logits_BlV_list = list(torch.split(logits_BlV_all,patch_nums_per_level,dim=1)) + + B = logits_BlV_all.shape[0] + + if g_seed is None: rng = None + else: self.rng.manual_seed(g_seed); rng = self.rng + + if self.apply_spatial_patchify: + vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] + else: + vae_scale_schedule = scale_schedule + + summed_codes = 0 + num_stages_minus_1 = len(scale_schedule)-1 + + for si,logits_BlV in enumerate(logits_BlV_list): + pn= scale_schedule[si] + if self.use_bit_label: + tmp_bs, tmp_seq_len = logits_BlV.shape[:2] + logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) + + #idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + + # ##### + # idx_Bld = STGumbelArgmax.apply(logits_BlV, 0.5) + # tmp_tensor = torch.zeros_like(idx_Bld).to(idx_Bld.device) + # tmp_tensor[:,:,1:]=1 + # idx_Bld = idx_Bld * tmp_tensor + # idx_Bld = idx_Bld.sum(dim=-1) + # ##### + + ##### + idx_Bld = GumbelArgmax(logits_BlV, 0.5) + tmp_tensor = torch.zeros_like(idx_Bld).to(idx_Bld.device) + tmp_tensor[:,:,1:]=1 + idx_Bld = idx_Bld * tmp_tensor + idx_Bld = idx_Bld.sum(dim=-1) + ##### + + idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) + else: + idx_Bl = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + + ##### vae_type!=0 + ###si>=gt_leak + assert pn[0] == 1 + idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d] + if self.apply_spatial_patchify: # unpatchify operation + idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w] + idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w] + idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d] + idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d] + + + codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + if si != num_stages_minus_1: + summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) + else: + summed_codes += codes + + # if inference_mode: + # for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) + # else: + # assert self.num_block_chunks > 1 + # for block_chunk_ in self.block_chunks: + # for module in block_chunk_.module.module: + # (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) + + #vae_type != 0: + img = vae.decode(summed_codes.squeeze(-3)) + # img = (img + 1) / 2 + # img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) + return img + + def logits_to_img_discrete(self,logits_BlV_all,vae,scale_schedule,top_k=900,top_p=0.97,g_seed=1): + # logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si]) + patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # note important pn[0]==1? 
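+ # NOTE: logits_BlV_all concatenates every scale along the sequence dim; torch.split below
+ # recovers one logits tensor per scale. E.g. a hypothetical scale_schedule of
+ # [(1,1,1),(1,2,2),(1,4,4)] gives patch_nums_per_level=[1,4,16], i.e. chunks of length
+ # 1, 4 and 16 along dim=1.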
+ logits_BlV_list = list(torch.split(logits_BlV_all,patch_nums_per_level,dim=1)) + + B = logits_BlV_all.shape[0] + + if g_seed is None: rng = None + else: self.rng.manual_seed(g_seed); rng = self.rng + + if self.apply_spatial_patchify: + vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] + else: + vae_scale_schedule = scale_schedule + + summed_codes = 0 + num_stages_minus_1 = len(scale_schedule)-1 + + for si,logits_BlV in enumerate(logits_BlV_list): + pn= scale_schedule[si] + if self.use_bit_label: + tmp_bs, tmp_seq_len = logits_BlV.shape[:2] + logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) + + idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + + # ##### + # idx_Bld = GumbelArgmax(logits_BlV, 0.5) + # tmp_tensor = torch.zeros_like(idx_Bld).to(idx_Bld.device) + # tmp_tensor[:,:,1:]=1 + # idx_Bld = idx_Bld * tmp_tensor + # idx_Bld = idx_Bld.sum(dim=-1) + # ##### + + idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) + else: + idx_Bl = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + + ##### vae_type!=0 + ###si>=gt_leak + assert pn[0] == 1 + idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d] + if self.apply_spatial_patchify: # unpatchify operation + idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w] + idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w] + idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d] + idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d] + + + codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + if si != num_stages_minus_1: + summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) + else: + summed_codes += codes + + # if inference_mode: + # for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) + # else: + # assert self.num_block_chunks > 1 + # for block_chunk_ in self.block_chunks: + # for module in block_chunk_.module.module: + # (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) + + #vae_type != 0: + img = vae.decode(summed_codes.squeeze(-3)) + # img = (img + 1) / 2 + # img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) + return img + + @torch.no_grad() + def autoregressive_infer_cfg( + self, + vae=None, + scale_schedule=None, + label_B_or_BLT=None, + B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None, + g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0, + returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False, + cfg_exp_k: float=0.0, cfg_insertion_layer=[-5], + vae_type=0, softmax_merge_topk=-1, ret_img=False, + trunk_scale=1000, + gt_leak=0, gt_ls_Bl=None, + inference_mode=False, + save_img_path=None, + sampling_per_bits=1, + lq_images=None, + ): # returns List[idx_Bl] + if g_seed is None: rng = None + else: self.rng.manual_seed(g_seed); rng = self.rng + assert len(cfg_list) >= len(scale_schedule) + assert len(tau_list) >= len(scale_schedule) + + # scale_schedule is used by infinity, vae_scale_schedule is used by vae if there exists a spatial patchify, + # we need to convert scale_schedule to vae_scale_schedule by multiply 2 to h and w + if self.apply_spatial_patchify: + vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in 
scale_schedule] + else: + vae_scale_schedule = scale_schedule + + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + if any(np.array(cfg_list) != 1): + bs = 2*B + if not negative_label_B_or_BLT: + kv_compact_un = kv_compact.clone() + total = 0 + for le in lens: + kv_compact_un[total:total+le] = (self.cfg_uncond)[:le] + total += le + kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) + cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0) + else: + kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT + kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) + cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0) + max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un) + else: + bs = B + + kv_compact = self.text_norm(kv_compact) + sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) # sos shape: [2, 4096] + kv_compact = self.text_proj_for_ca(kv_compact) # kv_compact shape: [304, 4096] + ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1) + + with torch.amp.autocast('cuda', enabled=False): + cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous() + accu_BChw, cur_L, ret = None, 0, [] # current length, list of reconstructed images + idx_Bl_list, idx_Bld_list = [], [] + + if inference_mode: + for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True) + else: + assert self.num_block_chunks > 1 + for block_chunk_ in self.block_chunks: + for module in block_chunk_.module.module: + (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True) + + abs_cfg_insertion_layers = [] + add_cfg_on_logits, add_cfg_on_probs = False, False + leng = len(self.unregistered_blocks) + for item in cfg_insertion_layer: + if item == 0: # add cfg on logits + add_cfg_on_logits = True + elif item == 1: # add cfg on probs + add_cfg_on_probs = True # todo in the future, we may want to add cfg on logits and probs + elif item < 0: # determine to add cfg at item-th layer's output + assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={self.num_block_chunks}' + abs_cfg_insertion_layers.append(leng+item) + else: + raise ValueError(f'cfg_insertion_layer: {item} is not valid') + + num_stages_minus_1 = len(scale_schedule)-1 + summed_codes = 0 + + ### need to change + cfg = cfg_list[0] + ### need to change + + x_BLC_lq = self.car_control_convs(lq_images) + x_BLC_lq = x_BLC_lq.view(B, self.C, -1).transpose(1, 2).contiguous() + if cfg!=1: + x_BLC_lq = torch.cat([x_BLC_lq,x_BLC_lq],dim=0) + + patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # note important pn[0]==1? 
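+ # NOTE: single-step path: the token sequence is initialized from Gaussian noise below, only
+ # position 0 is overwritten with the <sos>/text token, and the LQ-image tokens produced by
+ # car_control_convs ride along the sequence dim as conditioning. index=0 appears to mirror
+ # the training-time masking above, where index counts the leading scales kept un-noised.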
+ patch_nums_per_level_acc = [np.sum(patch_nums_per_level[:j+1]) for j in range(len(patch_nums_per_level))] + x_BLC = torch.randn((bs,patch_nums_per_level_acc[-1],last_stage.shape[-1])).to(last_stage.device) + x_BLC[:,:1,:] = last_stage + l_end = x_BLC.shape[1] + + index = torch.zeros((bs,)).to(x_BLC.device) ###change from torch.ones to torch.zeros + t_emb = dist.timestep_embedding(index, self.C//4, repeat_only=False) + t_emb = self.time_embed(t_emb) + + x_BLC = torch.cat([x_BLC,x_BLC_lq],dim = 1) + x_BLC = x_BLC + t_emb.unsqueeze(1) + + layer_idx = 0 + + for block_idx, b in enumerate(self.block_chunks): + # last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape [5], ca_kv[2]: int + if self.add_lvl_embeding_only_first_block and block_idx == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad=0) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad=0) + for m in b.module: + ### need to change scale_ind = si + x_BLC = m(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=0) + ### need to change + if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers): + # print(f'add cfg={cfg} on {layer_idx}-th layer output') + x_BLC = cfg * x_BLC[:B] + (1-cfg) * x_BLC[B:] + x_BLC = torch.cat((x_BLC, x_BLC), 0) + layer_idx += 1 + + ### need to change + if (cfg != 1) and add_cfg_on_logits: + logits_BlV = self.get_logits(x_BLC[:, :l_end], cond_BD).mul(1/tau_list[0]) + logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:] + else: + logits_BlV = self.get_logits(x_BLC[:B, :l_end], cond_BD[:B]).mul(1/tau_list[0]) + ### need to change + + img = self.logits_to_img_discrete(logits_BlV_all=logits_BlV, + vae=vae, + scale_schedule=scale_schedule, + top_k=top_k, + top_p=top_p, + g_seed=g_seed) + + + if inference_mode: + for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) + else: + assert self.num_block_chunks > 1 + for block_chunk_ in self.block_chunks: + for module in block_chunk_.module.module: + (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) + + img = (img + 1) / 2 + img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) + + return None,None,img + + @torch.no_grad() + def autoregressive_infer_cfg_multi_step( + self, + vae=None, + scale_schedule=None, + label_B_or_BLT=None, + B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None, + g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0, + returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False, + cfg_exp_k: float=0.0, cfg_insertion_layer=[-5], + vae_type=0, softmax_merge_topk=-1, ret_img=False, + trunk_scale=1000, + gt_leak=0, gt_ls_Bl=None, + inference_mode=False, + save_img_path=None, + sampling_per_bits=1, + lq_images=None, + index_list=None + ): # returns List[idx_Bl] + if g_seed is None: rng = None + else: self.rng.manual_seed(g_seed); rng = self.rng + assert len(cfg_list) >= len(scale_schedule) + assert len(tau_list) >= len(scale_schedule) + + assert index_list[0] == 0 + + # scale_schedule is used by infinity, vae_scale_schedule is used by vae if there exists a spatial patchify, + # we need to convert scale_schedule to vae_scale_schedule by multiply 2 to h and w + if self.apply_spatial_patchify: + vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in 
scale_schedule] + else: + vae_scale_schedule = scale_schedule + + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + if any(np.array(cfg_list) != 1): + bs = 2*B + if not negative_label_B_or_BLT: + kv_compact_un = kv_compact.clone() + total = 0 + for le in lens: + kv_compact_un[total:total+le] = (self.cfg_uncond)[:le] + total += le + kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) + cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0) + else: + kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT + kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) + cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0) + max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un) + else: + bs = B + + kv_compact = self.text_norm(kv_compact) + sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) # sos shape: [2, 4096] + kv_compact = self.text_proj_for_ca(kv_compact) # kv_compact shape: [304, 4096] + ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1) + + with torch.amp.autocast('cuda', enabled=False): + cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous() + accu_BChw, cur_L, ret = None, 0, [] # current length, list of reconstructed images + idx_Bl_list, idx_Bld_list = [], [] + + if inference_mode: + for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True) + else: + assert self.num_block_chunks > 1 + for block_chunk_ in self.block_chunks: + for module in block_chunk_.module.module: + (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True) + + abs_cfg_insertion_layers = [] + add_cfg_on_logits, add_cfg_on_probs = False, False + leng = len(self.unregistered_blocks) + for item in cfg_insertion_layer: + if item == 0: # add cfg on logits + add_cfg_on_logits = True + elif item == 1: # add cfg on probs + add_cfg_on_probs = True # todo in the future, we may want to add cfg on logits and probs + elif item < 0: # determine to add cfg at item-th layer's output + assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={self.num_block_chunks}' + abs_cfg_insertion_layers.append(leng+item) + else: + raise ValueError(f'cfg_insertion_layer: {item} is not valid') + + num_stages_minus_1 = len(scale_schedule)-1 + summed_codes = 0 + + ### need to change + cfg = cfg_list[0] + ### need to change + + x_BLC_lq = self.car_control_convs(lq_images) + + x_BLC_lq = x_BLC_lq.view(B, self.C, -1).transpose(1, 2).contiguous() + if cfg!=1: + x_BLC_lq = torch.cat([x_BLC_lq,x_BLC_lq],dim=0) + + patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # note important pn[0]==1? 
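+ # NOTE: multi-step path: index_list drives a coarse-to-fine refinement loop. At step `index`,
+ # everything after patch_nums_per_level_acc[index] is re-noised while the first `index`
+ # scales are kept from the previous round, hence the assert that index_list starts at 0
+ # (all noise except <sos>). E.g. a hypothetical index_list=[0, 3, 6] refines with 0, then 3,
+ # then 6 trusted scales.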
+ patch_nums_per_level_acc = [np.sum(patch_nums_per_level[:j+1]) for j in range(len(patch_nums_per_level))] + + + x_BLC = torch.zeros((bs,patch_nums_per_level_acc[-1],last_stage.shape[-1])).to(last_stage.device) + x_BLC[:,:1,:] = last_stage + l_end = x_BLC.shape[1] + + + for index in index_list: + + # add noise + noise = torch.randn_like(x_BLC).to(x_BLC.device) + x_BLC[:,patch_nums_per_level_acc[index]:,:] = noise[:,patch_nums_per_level_acc[index]:,:] + + # cat x_BLC_lq + x_BLC = torch.cat([x_BLC,x_BLC_lq],dim = 1) + + # add time_embedding + index_tensor = torch.full((bs,),index).to(x_BLC.device) + t_emb = dist.timestep_embedding(index_tensor, self.C//4, repeat_only=False) + t_emb = self.time_embed(t_emb) + + x_BLC = x_BLC + t_emb.unsqueeze(1) # unsqueeze so the [bs, C] timestep embedding broadcasts over the seq dim, as in autoregressive_infer_cfg + + layer_idx = 0 + for block_idx, b in enumerate(self.block_chunks): + # last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape [5], ca_kv[2]: int + if self.add_lvl_embeding_only_first_block and block_idx == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad=0) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad=0) + for m in b.module: + ### need to change scale_ind = si + x_BLC = m(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=None, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=0) + ### need to change + if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers): + # print(f'add cfg={cfg} on {layer_idx}-th layer output') + x_BLC = cfg * x_BLC[:B] + (1-cfg) * x_BLC[B:] + x_BLC = torch.cat((x_BLC, x_BLC), 0) + layer_idx += 1 + + ### need to change + if (cfg != 1) and add_cfg_on_logits: + logits_BlV = self.get_logits(x_BLC[:, :l_end], cond_BD).mul(1/tau_list[0]) + logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:] + else: + logits_BlV = self.get_logits(x_BLC[:B, :l_end], cond_BD[:B]).mul(1/tau_list[0]) + + if index == index_list[-1]: + logits_final = logits_BlV + else: + if self.use_bit_label: + tmp_bs, tmp_seq_len = logits_BlV.shape[:2] + logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) + idx_Bld = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) + else: + idx_Bl = sample_with_top_k_top_p_also_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] # note: only the bit-label branch above produces the idx_Bld consumed below + + cum_var_input = 0 + x_BLC_wo_prefix = [] + idx_Bld_list = list(torch.split(idx_Bld,patch_nums_per_level,dim=1)) + for si, bit_indices in enumerate(idx_Bld_list): + + _, _, d_vae = bit_indices.shape + bit_indices = bit_indices.reshape((B,vae_scale_schedule[si][0],vae_scale_schedule[si][1],vae_scale_schedule[si][2],d_vae)) + + quantized = vae.quantizer.lfq.indices_to_codes(bit_indices, label_type='bit_label') + quantized_up = F.interpolate(quantized, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) + cum_var_input += quantized_up + + if si < len(vae_scale_schedule)-1: + this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_up) + if self.apply_spatial_patchify: + this_scale_input = torch.nn.functional.pixel_unshuffle(this_scale_input.squeeze(-3), 2) + x_BLC_wo_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0,2,1)) + + x_BLC_wo_prefix = torch.cat(x_BLC_wo_prefix, 1) + 
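+ # NOTE: the sampled bit labels were decoded to codes, upsampled to the final resolution and
+ # accumulated (cum_var_input), then re-interpolated to each next scale -- this appears to
+ # rebuild the same x_BLC_wo_prefix layout the training forward consumes, but from the
+ # model's own samples; the cat below re-assembles the full-length sequence for the next step.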
x_BLC = torch.cat((last_stage, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1) + assert x_BLC.shape[1] == l_end + + + img = self.logits_to_img_discrete(logits_BlV_all=logits_final, + vae=vae, + scale_schedule=scale_schedule, + top_k=top_k, + top_p=top_p, + g_seed=g_seed) + + if inference_mode: + for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) + else: + assert self.num_block_chunks > 1 + for block_chunk_ in self.block_chunks: + for module in block_chunk_.module.module: + (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) + + img = (img + 1) / 2 + img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) + + return None,None,img + + @for_visualize + def vis_key_params(self, ep): + return + + def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False): + for k in state_dict: + if 'cfg_uncond' in k: + old, new = state_dict[k], self.cfg_uncond.data + min_tlen = min(old.shape[0], new.shape[0]) + if min_tlen == old.shape[0]: + state_dict[k] = torch.cat((old.to(device=new.device, dtype=new.dtype), new[min_tlen:])) + else: + state_dict[k] = old[:min_tlen] + + for buf_name in ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen'): + state_dict.pop(buf_name, None) + if hasattr(self, buf_name): + state_dict[buf_name] = getattr(self, buf_name) + + return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign) + + def special_init( + self, + aln_init: float, + aln_gamma_init: float, + scale_head: float, + scale_proj: int, + ): + # init head's norm + if isinstance(self.head_nm, AdaLNBeforeHead): + self.head_nm.ada_lin[-1].weight.data.mul_(aln_init) # there's no gamma for head + if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None: + self.head_nm.ada_lin[-1].bias.data.zero_() + + # init head's proj + if scale_head >= 0: + if isinstance(self.head, nn.Linear): + self.head.weight.data.mul_(scale_head) + self.head.bias.data.zero_() + elif isinstance(self.head, nn.Sequential): + self.head[-1].weight.data.mul_(scale_head) + self.head[-1].bias.data.zero_() + + depth = len(self.unregistered_blocks) + for block_idx, sab in enumerate(self.unregistered_blocks): + sab: Union[SelfAttnBlock, CrossAttnBlock] + # init proj + scale = 1 / math.sqrt(2*depth if scale_proj == 1 else 2*(1 + block_idx)) + if scale_proj == 1: + if self.t2i: + sab.sa.proj.weight.data.mul_(scale) + sab.ca.proj.weight.data.mul_(scale) + else: + sab.attn.proj.weight.data.mul_(scale) + sab.ffn.fc2.weight.data.mul_(scale) + # if sab.using_swiglu: + # nn.init.ones_(sab.ffn.fcg.bias) + # nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5) + + # init ada_lin + if hasattr(sab, 'ada_lin'): + lin = sab.ada_lin[-1] + lin.weight.data[:2*self.C].mul_(aln_gamma_init) # init gamma + lin.weight.data[2*self.C:].mul_(aln_init) # init scale and shift + if hasattr(lin, 'bias') and lin.bias is not None: + lin.bias.data.zero_() + elif hasattr(sab, 'ada_gss'): + sab.ada_gss.data[:, :, :2, :].mul_(aln_gamma_init) # init gamma + sab.ada_gss.data[:, :, 2:, :].mul_(aln_init) # init scale and shift + + def extra_repr(self): + return f'drop_path_rate={self.drop_path_rate}' + + def get_layer_id_and_scale_exp(self, para_name: str): + raise NotImplementedError + +#CAR code +class ControlConditionEmbedding(nn.Module): + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] 
= (64, 128, 256, 512, 1024), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + +class ControlConditionEmbedding_patch_size_32(nn.Module): + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1,stride=2) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + +class FP32_Layernorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), + self.eps).to(origin_dtype) + +class CInfinity(Infinity): + def __init__( + self, vae_local, + text_channels=0, text_maxlen=0, # text-cond generation + selecting_idx=None, # class-cond generation + embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4., # model's architecture + drop_rate=0., drop_path_rate=0., # drop out and drop path + norm_eps=1e-6, rms_norm=False, # norm layer + shared_aln=False, head_aln=True, # adaptive norm + cond_drop_rate=0.1, # for classifier-free guidance + rand_uncond=False, + cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False, + raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), + head_depth=1, + top_p=0.0, top_k=0.0, + customized_flash_attn=False, fused_mlp=False, fused_norm=False, + block_chunks=1, + checkpointing=None, + pad_to_multiplier=0, + use_flex_attn=False, + batch_size=2, + add_lvl_embeding_only_first_block=1, + use_bit_label=1, + rope2d_each_sa_layer=0, + rope2d_normalized_by_hw=0, + pn=None, + train_h_div_w_list=None, + video_frames=1, + always_training_scales=20, + apply_spatial_patchify = 0, + inference_mode=False, + + ): + super(CInfinity,self).__init__(vae_local, + text_channels, text_maxlen, # text-cond generation + selecting_idx, # class-cond generation + embed_dim, depth, num_heads, mlp_ratio, # model's architecture + drop_rate, drop_path_rate, # drop out and drop path + 
norm_eps, rms_norm, # norm layer + shared_aln, head_aln, # adaptive norm + cond_drop_rate, # for classifier-free guidance + rand_uncond, + cross_attn_layer_scale, nm0, tau, cos_attn, swiglu, + raw_scale_schedule, + head_depth, + top_p, top_k, + customized_flash_attn, fused_mlp, fused_norm, + block_chunks, + checkpointing, + pad_to_multiplier, + use_flex_attn, + batch_size, + add_lvl_embeding_only_first_block, + use_bit_label, + rope2d_each_sa_layer, + rope2d_normalized_by_hw, + pn, + train_h_div_w_list, + video_frames, + always_training_scales, + apply_spatial_patchify, + inference_mode,) + + + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.car_var_conv = nn.Conv2d(self.C, self.C, kernel_size=conv_in_kernel, padding=conv_in_padding) + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + self.drop_path_rate = drop_path_rate + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + if fused_norm: + fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm + if fused_norm_func is not None: # pre-compile + B = 2 + x = torch.randn(B, 1, self.C).requires_grad_(True) + scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + # fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale, shift=shift).mean().backward() + del B, x, scale, shift + else: + fused_norm_func = None + + self.car_unregistered_blocks = [] + for block_idx in range(depth//2): + block = (CrossAttnBlock if self.t2i else SelfAttnBlock)( + embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer, + num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn, + swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func, + checkpointing_sa_only=self.checkpointing == 'self-attn', + use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw, + ) + self.car_unregistered_blocks.append(block) + + if self.num_block_chunks == 1: + self.car_blocks = nn.ModuleList(self.car_unregistered_blocks) + else: + self.car_block_chunks = nn.ModuleList() + for i in range(self.num_block_chunks//2): + self.car_block_chunks.append(MultipleLayers(self.car_unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk)) + + car_norm_layer = FP32_Layernorm + car_skip_norm = [] + car_skip_linear = [] + + for _ in range(depth // 2): + car_skip_norm.append(car_norm_layer(2 * self.C, elementwise_affine=True, eps=1e-6)) + car_skip_linear.append(nn.Linear(2 * self.C, self.C)) + + # for _ in range(depth // 2): + # car_skip_norm.append(car_norm_layer(self.C, elementwise_affine=True, eps=1e-6)) + # car_skip_linear.append(nn.Linear(self.C, self.C)) + + self.car_skip_norm = nn.ModuleList(car_skip_norm) + self.car_skip_linear = nn.ModuleList(car_skip_linear) + + + def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]], + cfg_infer=False,x_BLC_w_prefix_lq=None, + **kwargs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV + """ + label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k) + :return: logits BLV, V is vocab_size + """ + if cfg_infer: + return 
self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs) + + x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32 + x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.float() + B = x_BLC_wo_prefix.shape[0] + + # [1. get input sequence x_BLC] + with torch.amp.autocast('cuda', enabled=False): + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + # drop cond + total = 0 + for le in lens: + if random.random() < self.cond_drop_rate: + kv_compact[total:total+le] = self.cfg_uncond[:le] + total += le + must_on_graph = self.cfg_uncond[0, 0] * 0 + kv_compact = self.text_norm(kv_compact).contiguous() + sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32 + kv_compact = self.text_proj_for_ca(kv_compact).contiguous() + kv_compact[0, 0] += must_on_graph + ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + + cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32 + + sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1) + x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1) + x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq)) + + + # #car_input code + # control_f = [] + # if control_tensors is not None: + # assert control_tensors[0].shape[0] == B + # for control_tensor in control_tensors: + # control_i = self.car_control_convs(control_tensor) + # control_f.append(control_i) + # car_input = [] + # var_x = sos.transpose(1, 2).contiguous().reshape(B, self.C, self.raw_scale_schedule[0], self.raw_scale_schedule[0]) + # var_x = self.car_var_conv(var_x) + # car_x = var_x + control_f[0] + # car_x = car_x.view(B, self.C, -1).transpose(1, 2).contiguous() + # car_input.append(car_x) + # for si, (pn, var_input) in enumerate(zip(self.raw_scale_schedule[1:], x_BLC_wo_prefix)): + # var_x = self.word_embed(var_input.float()) + # var_x = var_x.transpose(1, 2).contiguous().reshape(B, self.C, pn, pn) + # var_x = self.car_var_conv(var_x) + # car_x = var_x + control_f[si + 1] + # car_x = car_x.view(B, self.C, -1).transpose(1, 2).contiguous() + # car_input.append(car_x) + # car_input = torch.cat(car_input, dim=1) + + # x_BLC_lq = x_BLC_lq + x_BLC + # note important only for 512*512 + patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # note important pn[0]==1? + x_BLC_lq_list = list(torch.split(x_BLC_lq,patch_nums_per_level,dim=1)) + x_BLC_list = list(torch.split(x_BLC,patch_nums_per_level,dim=1)) + x_BLC_lq_list_new = [] + CVae = x_BLC.shape[-1] + for si, (pn2, var_x) in enumerate(zip(patch_nums_per_level, x_BLC_list)): + pn = int(pn2**0.5) + var_x = var_x.transpose(1, 2).contiguous().reshape(B, CVae, pn, pn) + var_x = self.car_var_conv(var_x) + car_x = var_x + x_BLC_lq_list[si].transpose(1, 2).contiguous().reshape(B, CVae, pn, pn) + car_x = car_x.view(B, CVae, -1).transpose(1, 2).contiguous() + x_BLC_lq_list_new.append(car_x) + x_BLC_lq = torch.cat(x_BLC_lq_list_new,dim=1) + + + # [1.1. 
pad the seqlen dim] + l_end = x_BLC.shape[1] + need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0 + + if self.customized_flash_attn: + Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end] + Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end] + attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen) + # todo: solve need_to_pad here + elif self.use_flex_attn: + if need_to_pad: + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + #note x_BLC_lq padding + x_BLC_lq = F.pad(x_BLC_lq, (0, 0, 0, need_to_pad)) + + assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0' + attn_bias_or_two_vector = None + else: + d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1) + dT = d.transpose(1, 2) # dT: 11L + attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end) + attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL + if need_to_pad: + attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf) + attn_bias[0, 0, l_end:, 0] = 0 + #note x_BLC_lq padding + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + x_BLC_lq = F.pad(x_BLC_lq,(0, 0, 0, need_to_pad)) + + attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device) + + if self.use_flex_attn: + attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)] + else: + attn_fn = None + + # [2. block loop] + SelfAttnBlock.forward, CrossAttnBlock.forward + checkpointing_full_block = self.checkpointing == 'full-block' and self.training + + control_residual_f = [] + if self.num_block_chunks == 1: + for i, b in enumerate(self.car_blocks): + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + if checkpointing_full_block: + x_BLC_lq = torch.utils.checkpoint.checkpoint(b, x_BLC_lq, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False) + else: + x_BLC_lq = b(x=x_BLC_lq, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) + control_residual_f.append(x_BLC_lq) + else: + for i, chunk in enumerate(self.car_block_chunks): # this path + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + x_BLC_lq = chunk(x=x_BLC_lq, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) + control_residual_f.append(x_BLC_lq) + + if self.num_block_chunks == 1: + for i, b in enumerate(self.blocks): + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + + if i >= len(self.blocks) // 2: + con_f = control_residual_f.pop() + cat = torch.cat([x_BLC, 
con_f], dim=-1) + cat = self.car_skip_norm[i - len(self.blocks) // 2](cat) + x_BLC = self.car_skip_linear[i - len(self.blocks) // 2](cat) + + if checkpointing_full_block: + x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False) + else: + x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) + else: + for i, chunk in enumerate(self.block_chunks): # this path + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + + if i >= len(self.block_chunks) // 2: + con_f = control_residual_f.pop() + cat = torch.cat([x_BLC, con_f], dim=-1) + cat = self.car_skip_norm[i - len(self.block_chunks) // 2](cat) + x_BLC = self.car_skip_linear[i - len(self.block_chunks) // 2](cat) + + # if i >= len(self.block_chunks) // 2: + # con_f = control_residual_f.pop() + # # cat = torch.cat([x_BLC, con_f], dim=-1) + # con_f = self.car_skip_norm[i - len(self.block_chunks) // 2](con_f) + # con_f = self.car_skip_linear[i - len(self.block_chunks) // 2](con_f) + # x_BLC = x_BLC + con_f + + x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) + # [3. unpad the seqlen dim, and then get logits] + return self.get_logits(x_BLC[:, :l_end], cond_BD) # return logits BLV, V is vocab_size + + @torch.no_grad() + def car_inference( + self, + vae=None, + scale_schedule=None, + label_B_or_BLT=None, + B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None, + g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0, + returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False, + cfg_exp_k: float=0.0, cfg_insertion_layer=[-5], + vae_type=0, softmax_merge_topk=-1, ret_img=False, + trunk_scale=1000, + gt_leak=0, gt_ls_Bl=None, + inference_mode=False, + save_img_path=None, + sampling_per_bits=1, + x_BLC_w_prefix_lq=None + ): # returns List[idx_Bl] + if g_seed is None: rng = None + else: self.rng.manual_seed(g_seed); rng = self.rng + assert len(cfg_list) >= len(scale_schedule) + assert len(tau_list) >= len(scale_schedule) + + # scale_schedule is used by infinity, vae_scale_schedule is used by vae if there exists a spatial patchify, + # we need to convert scale_schedule to vae_scale_schedule by multiply 2 to h and w + if self.apply_spatial_patchify: + vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule] + else: + vae_scale_schedule = scale_schedule + + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + if any(np.array(cfg_list) != 1): + bs = 2*B + if not negative_label_B_or_BLT: + kv_compact_un = kv_compact.clone() + total = 0 + # print(f"kv_compact_un.shape {kv_compact_un.shape} self.cfg_uncond.shape {self.cfg_uncond.shape}") + # print(lens) + #my code + #FSDP flattening + # if self.cfg_uncond.ndim ==1: + # last_dim = kv_compact_un.shape[-1] + # cfg_uncond = self.cfg_uncond.reshape(-1,last_dim) + for le in lens: + kv_compact_un[total:total+le] = (self.cfg_uncond)[:le] + total += le + kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0) + cu_seqlens_k = 
torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0)
+            else:
+                kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
+                kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
+                cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0)
+                max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
+        else:
+            bs = B
+
+        kv_compact = self.text_norm(kv_compact)
+        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) # sos shape: [2, 4096]
+        kv_compact = self.text_proj_for_ca(kv_compact) # kv_compact shape: [304, 4096]
+        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
+        last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)
+        #####
+        x_BLC_w_prefix_lq = x_BLC_w_prefix_lq.expand(bs, -1, -1)
+        x_BLC_lq = self.word_embed(self.norm0_ve(x_BLC_w_prefix_lq))
+        #####
+
+        with torch.amp.autocast('cuda', enabled=False):
+            cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()
+        accu_BChw, cur_L, ret = None, 0, [] # current length, list of reconstructed images
+        idx_Bl_list, idx_Bld_list = [], []
+
+        if inference_mode:
+            for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True)
+        else:
+            assert self.num_block_chunks > 1
+            for block_chunk_ in self.block_chunks:
+                for module in block_chunk_.module.module:
+                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True)
+        #####
+        if inference_mode:
+            for b in self.car_unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True)
+        else:
+            assert self.num_block_chunks > 1
+            for block_chunk_ in self.car_block_chunks:
+                for module in block_chunk_.module.module:
+                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True)
+        #####
+        abs_cfg_insertion_layers = []
+        add_cfg_on_logits, add_cfg_on_probs = False, False
+        leng = len(self.unregistered_blocks)
+        for item in cfg_insertion_layer:
+            if item == 0: # add cfg on logits
+                add_cfg_on_logits = True
+            elif item == 1: # add cfg on probs
+                add_cfg_on_probs = True # todo: in the future we may want to add cfg on both logits and probs
+            elif item < 0: # add cfg at the output of the (leng+item)-th layer
+                assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={leng}'
+                abs_cfg_insertion_layers.append(leng+item)
+            else:
+                raise ValueError(f'cfg_insertion_layer: {item} is not valid')
+
+        num_stages_minus_1 = len(scale_schedule)-1
+        summed_codes = 0
+        #####
+        patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # NOTE: assumes pn[0] == 1, so each level contributes pn[1]*pn[2] tokens
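+        # The control (CAR) branch below is paired with the backbone like a U-Net: each control
+        # chunk pushes its output onto control_residual_f, and the second half of the main blocks
+        # pops these features (LIFO) and fuses them through car_skip_norm / car_skip_linear.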
+ x_BLC_lq_list = list(torch.split(x_BLC_lq,patch_nums_per_level,dim=1)) + # for si, (pn2, var_x) in enumerate(zip(patch_nums_per_level, x_BLC_list)): + # pn = int(pn2**0.5) + # var_x = var_x.transpose(1, 2).contiguous().reshape(B, CVae, pn, pn) + # var_x = self.car_var_conv(var_x) + # car_x = var_x + x_BLC_lq_list[si].transpose(1, 2).contiguous().reshape(B, CVae, pn, pn) + # car_x = car_x.view(B, CVae, -1).transpose(1, 2).contiguous() + # x_BLC_lq_list_new.append(car_x) + # x_BLC_lq = torch.cat(x_BLC_lq_list_new,dim=1) + ##### + for si, pn in enumerate(scale_schedule): # si: i-th segment + cfg = cfg_list[si] + if si >= trunk_scale: + break + cur_L += np.array(pn).prod() + + need_to_pad = 0 + attn_fn = None + if self.use_flex_attn: + # need_to_pad = (self.pad_to_multiplier - cur_L % self.pad_to_multiplier) % self.pad_to_multiplier + # if need_to_pad: + # last_stage = F.pad(last_stage, (0, 0, 0, need_to_pad)) + attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None) + + # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item' + layer_idx = 0 + ##### + control_residual_f = [] + if x_BLC_lq_list is not None: + last_stage_channel = last_stage.shape[-1] + control_x = x_BLC_lq_list[si].transpose(1, 2).contiguous().reshape(bs, last_stage_channel, pn[1], pn[2]) + var_x = last_stage.transpose(1, 2).contiguous().reshape(bs, last_stage_channel, pn[1], pn[2]) + var_x = self.car_var_conv(var_x) + control_x = var_x + control_x + control_x = control_x.view(bs, last_stage_channel, -1).transpose(1, 2) + # for cb in self.car_blocks: + # control_x = cb(x=control_x, cond_BD=cond_BD_or_gss, attn_bias=None) + # control_residual_f.append(control_x) + for i, chunk in enumerate(self.car_block_chunks): # this path + if self.add_lvl_embeding_only_first_block and i == 0: + control_x = self.add_lvl_embeding(control_x, si, scale_schedule, need_to_pad=need_to_pad) + if not self.add_lvl_embeding_only_first_block: + control_x = self.add_lvl_embeding(control_x, si, scale_schedule, need_to_pad=need_to_pad) + # for m in chunk.module: #used for FSDP + # control_x = m(x=control_x, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si) + control_x = chunk(x=control_x, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) + #chunk(x=x_BLC_lq, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) + control_residual_f.append(control_x) + + ##### + for block_idx, b in enumerate(self.block_chunks): + # last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape [5], ca_kv[2]: int + if self.add_lvl_embeding_only_first_block and block_idx == 0: + last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) + if not self.add_lvl_embeding_only_first_block: + last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) + + ##### + if block_idx >= len(self.block_chunks) // 2: + con_f = control_residual_f.pop() + cat = torch.cat([last_stage, 
con_f], dim=-1) + cat = self.car_skip_norm[block_idx - len(self.block_chunks) // 2](cat) + last_stage = self.car_skip_linear[block_idx - len(self.block_chunks) // 2](cat) + ##### + #my code + #for m in b.module: need to change + for m in b.module: #used for FSDP + last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si) + if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers): + # print(f'add cfg={cfg} on {layer_idx}-th layer output') + last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:] + last_stage = torch.cat((last_stage, last_stage), 0) + layer_idx += 1 + + if (cfg != 1) and add_cfg_on_logits: + # print(f'add cfg on add_cfg_on_logits') + logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si]) + logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:] + else: + logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si]) + + if self.use_bit_label: + tmp_bs, tmp_seq_len = logits_BlV.shape[:2] + logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) + idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) + else: + idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + if vae_type != 0: + assert returns_vemb + if si < gt_leak: + idx_Bld = gt_ls_Bl[si] + else: + assert pn[0] == 1 + idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d] + if self.apply_spatial_patchify: # unpatchify operation + idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w] + idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w] + idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d] + idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d] + + idx_Bld_list.append(idx_Bld) + codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + if si != num_stages_minus_1: + summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) + last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_down) # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + last_stage = last_stage.squeeze(-3) # [B, d, h, w] or [B, d, 2h, 2w] + if self.apply_spatial_patchify: # patchify operation + last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2) # [B, 4d, h, w] + last_stage = last_stage.reshape(*last_stage.shape[:2], -1) # [B, d, h*w] or [B, 4d, h*w] + last_stage = torch.permute(last_stage, [0,2,1]) # [B, h*w, d] or [B, h*w, 4d] + else: + summed_codes += codes + else: + if si < gt_leak: + idx_Bl = gt_ls_Bl[si] + h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float() # BlC + + # h_BChw = h_BChw.float().transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1]) + h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2]) + ret.append(h_BChw if returns_vemb != 0 else idx_Bl) + idx_Bl_list.append(idx_Bl) + if si != num_stages_minus_1: + accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule) + + if si != 
num_stages_minus_1:
+                last_stage = self.word_embed(self.norm0_ve(last_stage))
+                last_stage = last_stage.repeat(bs//B, 1, 1)
+
+        if inference_mode:
+            for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False)
+        else:
+            assert self.num_block_chunks > 1
+            for block_chunk_ in self.block_chunks:
+                for module in block_chunk_.module.module:
+                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False)
+        #####
+        if inference_mode:
+            for b in self.car_unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False)
+        else:
+            assert self.num_block_chunks > 1
+            for block_chunk_ in self.car_block_chunks:
+                for module in block_chunk_.module.module:
+                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False)
+        #####
+
+        if not ret_img:
+            return ret, idx_Bl_list, []
+
+        if vae_type != 0:
+            img = vae.decode(summed_codes.squeeze(-3))
+        else:
+            img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True)
+
+        img = (img + 1) / 2
+        img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
+        return ret, idx_Bl_list, img
+
+class CInfinity2(Infinity):
+    def __init__(
+        self, vae_local,
+        text_channels=0, text_maxlen=0, # text-cond generation
+        selecting_idx=None, # class-cond generation
+        embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4., # model's architecture
+        drop_rate=0., drop_path_rate=0., # drop out and drop path
+        norm_eps=1e-6, rms_norm=False, # norm layer
+        shared_aln=False, head_aln=True, # adaptive norm
+        cond_drop_rate=0.1, # for classifier-free guidance
+        rand_uncond=False,
+        cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False,
+        raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
+        head_depth=1,
+        top_p=0.0, top_k=0.0,
+        customized_flash_attn=False, fused_mlp=False, fused_norm=False,
+        block_chunks=1,
+        checkpointing=None,
+        pad_to_multiplier=0,
+        use_flex_attn=False,
+        batch_size=2,
+        add_lvl_embeding_only_first_block=1,
+        use_bit_label=1,
+        rope2d_each_sa_layer=0,
+        rope2d_normalized_by_hw=0,
+        pn=None,
+        train_h_div_w_list=None,
+        video_frames=1,
+        always_training_scales=20,
+        apply_spatial_patchify=0,
+        inference_mode=False,
+    ):
+        super().__init__(vae_local,
+            text_channels, text_maxlen, # text-cond generation
+            selecting_idx, # class-cond generation
+            embed_dim, depth, num_heads, mlp_ratio, # model's architecture
+            drop_rate, drop_path_rate, # drop out and drop path
+            norm_eps, rms_norm, # norm layer
+            shared_aln, head_aln, # adaptive norm
+            cond_drop_rate, # for classifier-free guidance
+            rand_uncond,
+            cross_attn_layer_scale, nm0, tau, cos_attn, swiglu,
+            raw_scale_schedule,
+            head_depth,
+            top_p, top_k,
+            customized_flash_attn, fused_mlp, fused_norm,
+            block_chunks,
+            checkpointing,
+            pad_to_multiplier,
+            use_flex_attn,
+            batch_size,
+            add_lvl_embeding_only_first_block,
+            use_bit_label,
+            rope2d_each_sa_layer,
+            rope2d_normalized_by_hw,
+            pn,
+            train_h_div_w_list,
+            video_frames,
+            always_training_scales,
+            apply_spatial_patchify,
+            inference_mode,)
+
+        # TODO: these control-branch modules may still need to change
+        self.car_control_convs = ControlConditionEmbedding(conditioning_embedding_channels=self.C)
+
+        conv_in_kernel = 3
+        conv_in_padding = (conv_in_kernel - 1) // 2
+        self.car_var_conv = nn.Conv2d(self.C, self.C, kernel_size=conv_in_kernel, padding=conv_in_padding)
+        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
+        self.drop_path_rate = drop_path_rate
+        dpr = [x.item() for x in torch.linspace(0,
drop_path_rate, depth)] + + if fused_norm: + fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm + if fused_norm_func is not None: # pre-compile + B = 2 + x = torch.randn(B, 1, self.C).requires_grad_(True) + scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True) + # fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale, shift=shift).mean().backward() + del B, x, scale, shift + else: + fused_norm_func = None + + self.car_unregistered_blocks = [] + for block_idx in range(depth//2): + block = (CrossAttnBlock if self.t2i else SelfAttnBlock)( + embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer, + num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn, + swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func, + checkpointing_sa_only=self.checkpointing == 'self-attn', + use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw, + ) + self.car_unregistered_blocks.append(block) + + if self.num_block_chunks == 1: + self.car_blocks = nn.ModuleList(self.car_unregistered_blocks) + else: + self.car_block_chunks = nn.ModuleList() + for i in range(self.num_block_chunks//2): + self.car_block_chunks.append(MultipleLayers(self.car_unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk)) + + car_norm_layer = FP32_Layernorm + car_skip_norm = [] + car_skip_linear = [] + for _ in range(depth // 2): + car_skip_norm.append(car_norm_layer(2 * self.C, elementwise_affine=True, eps=1e-6)) + car_skip_linear.append(nn.Linear(2 * self.C, self.C)) + self.car_skip_norm = nn.ModuleList(car_skip_norm) + self.car_skip_linear = nn.ModuleList(car_skip_linear) + + + def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]], + cfg_infer=False,x_BLC_lq=None, + **kwargs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV + """ + label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k) + :return: logits BLV, V is vocab_size + """ + if cfg_infer: + return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs) + + x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32 + x_BLC_lq = x_BLC_lq.float() + B = x_BLC_wo_prefix.shape[0] + + # [1. 
get input sequence x_BLC]
+        with torch.amp.autocast('cuda', enabled=False):
+            kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
+            # drop cond
+            total = 0
+            for le in lens:
+                if random.random() < self.cond_drop_rate:
+                    kv_compact[total:total+le] = self.cfg_uncond[:le]
+                total += le
+            must_on_graph = self.cfg_uncond[0, 0] * 0
+            kv_compact = self.text_norm(kv_compact).contiguous()
+            sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32
+            kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
+            kv_compact[0, 0] += must_on_graph
+            ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
+
+            cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32
+
+            sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
+            x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)
+
+        # mix the low-quality control tokens with the main sequence, scale by scale.
+        # NOTE: the square reshape below assumes pn[0] == 1 and square spatial scales
+        # (e.g. the 512x512 schedules), since pn is recovered as sqrt(pn[1]*pn[2])
+        patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule]
+        x_BLC_lq_list = list(torch.split(x_BLC_lq, patch_nums_per_level, dim=1))
+        x_BLC_list = list(torch.split(x_BLC, patch_nums_per_level, dim=1))
+        x_BLC_lq_list_new = []
+        CVae = x_BLC.shape[-1] # channel dim of the token sequence
+        for si, (pn2, var_x) in enumerate(zip(patch_nums_per_level, x_BLC_list)):
+            pn = int(pn2**0.5)
+            var_x = var_x.transpose(1, 2).contiguous().reshape(B, CVae, pn, pn)
+            var_x = self.car_var_conv(var_x)
+            car_x = x_BLC_lq_list[si].transpose(1, 2).contiguous().reshape(B, CVae, pn, pn)
+            car_x = self.car_control_convs(car_x)
+            car_x = var_x + car_x
+            car_x = car_x.view(B, CVae, -1).transpose(1, 2).contiguous()
+            x_BLC_lq_list_new.append(car_x)
+        x_BLC_lq = torch.cat(x_BLC_lq_list_new, dim=1)
+
+        # [1.1. 
pad the seqlen dim] + l_end = x_BLC.shape[1] + need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0 + + if self.customized_flash_attn: + Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end] + Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end] + attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen) + # todo: solve need_to_pad here + elif self.use_flex_attn: + if need_to_pad: + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + #note x_BLC_lq padding + x_BLC_lq = F.pad(x_BLC_lq, (0, 0, 0, need_to_pad)) + + assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0' + attn_bias_or_two_vector = None + else: + d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1) + dT = d.transpose(1, 2) # dT: 11L + attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end) + attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL + if need_to_pad: + attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf) + attn_bias[0, 0, l_end:, 0] = 0 + #note x_BLC_lq padding + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + x_BLC_lq = F.pad(x_BLC_lq,(0, 0, 0, need_to_pad)) + + attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device) + + if self.use_flex_attn: + attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)] + else: + attn_fn = None + + # [2. block loop] + SelfAttnBlock.forward, CrossAttnBlock.forward + checkpointing_full_block = self.checkpointing == 'full-block' and self.training + + control_residual_f = [] + if self.num_block_chunks == 1: + for i, b in enumerate(self.car_blocks): + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + if checkpointing_full_block: + x_BLC_lq = torch.utils.checkpoint.checkpoint(b, x_BLC_lq, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False) + else: + x_BLC_lq = b(x=x_BLC_lq, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) + control_residual_f.append(x_BLC_lq) + else: + for i, chunk in enumerate(self.car_block_chunks): # this path + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC_lq = self.add_lvl_embeding_for_x_BLC(x_BLC_lq, scale_schedule, need_to_pad) + x_BLC_lq = chunk(x=x_BLC_lq, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) + control_residual_f.append(x_BLC_lq) + + if self.num_block_chunks == 1: + for i, b in enumerate(self.blocks): + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + + if i >= len(self.blocks) // 2: + con_f = control_residual_f.pop() + cat = torch.cat([x_BLC, 
con_f], dim=-1) + cat = self.car_skip_norm[i - len(self.blocks) // 2](cat) + x_BLC = self.car_skip_linear[i - len(self.blocks) // 2](cat) + + if checkpointing_full_block: + x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False) + else: + x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) + else: + for i, chunk in enumerate(self.block_chunks): # this path + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + + if i >= len(self.block_chunks) // 2: + con_f = control_residual_f.pop() + cat = torch.cat([x_BLC, con_f], dim=-1) + cat = self.car_skip_norm[i - len(self.block_chunks) // 2](cat) + x_BLC = self.car_skip_linear[i - len(self.block_chunks) // 2](cat) + + x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) + + # [3. unpad the seqlen dim, and then get logits] + return self.get_logits(x_BLC[:, :l_end], cond_BD) # return logits BLV, V is vocab_size + + @torch.no_grad() + def forward_img_infinity(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]], + cfg_infer=False, + **kwargs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV + """ + label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k) + :return: logits BLV, V is vocab_size + """ + if cfg_infer: + return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs) + + x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32 + B = x_BLC_wo_prefix.shape[0] + + # [1. get input sequence x_BLC] + with torch.amp.autocast('cuda', enabled=False): + kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT + # drop cond + total = 0 + for le in lens: + if random.random() < self.cond_drop_rate: + kv_compact[total:total+le] = self.cfg_uncond[:le] + total += le + must_on_graph = self.cfg_uncond[0, 0] * 0 + kv_compact = self.text_norm(kv_compact).contiguous() + sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32 + kv_compact = self.text_proj_for_ca(kv_compact).contiguous() + kv_compact[0, 0] += must_on_graph + ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k + + cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32 + + sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1) + x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1) + # [1.1. 
pad the seqlen dim] + l_end = x_BLC.shape[1] + need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0 + + if self.customized_flash_attn: + Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end] + Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end] + attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen) + # todo: solve need_to_pad here + elif self.use_flex_attn: + if need_to_pad: + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0' + attn_bias_or_two_vector = None + else: + d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1) + dT = d.transpose(1, 2) # dT: 11L + attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end) + attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL + if need_to_pad: + attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf) + attn_bias[0, 0, l_end:, 0] = 0 + x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad)) + attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device) + + if self.use_flex_attn: + attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)] + else: + attn_fn = None + + # [2. block loop] + SelfAttnBlock.forward, CrossAttnBlock.forward + checkpointing_full_block = self.checkpointing == 'full-block' and self.training + if self.num_block_chunks == 1: + for i, b in enumerate(self.blocks): + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if checkpointing_full_block: + x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False) + else: + x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) + else: + for i, chunk in enumerate(self.block_chunks): # this path + if self.add_lvl_embeding_only_first_block and i == 0: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + if not self.add_lvl_embeding_only_first_block: + x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad) + x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) + + # [3. 
unpad the seqlen dim, and then get logits]
+        logits_BlV = self.get_logits(x_BLC[:, :l_end], cond_BD) # note: scale by 1/tau (cf. tau_list) if temperature sampling is desired
+        if self.use_bit_label:
+            tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
+            logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
+            idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=None, top_k=self.top_k, top_p=self.top_p, num_samples=1)[:, :, 0]
+            idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)
+            return idx_Bld
+        idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=None, top_k=self.top_k, top_p=self.top_p, num_samples=1)[:, :, 0]
+        # TODO: decode the sampled tokens back to an image, e.g.
+        # if vae_type != 0:
+        #     img = vae.decode(summed_codes.squeeze(-3))
+        # else:
+        #     img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True)
+        # img = (img + 1) / 2
+        # img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
+        return idx_Bl
+
+    @torch.no_grad()
+    def car_inference(
+        self,
+        vae=None,
+        scale_schedule=None,
+        label_B_or_BLT=None,
+        B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
+        g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
+        returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
+        cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
+        vae_type=0, softmax_merge_topk=-1, ret_img=False,
+        trunk_scale=1000,
+        gt_leak=0, gt_ls_Bl=None,
+        inference_mode=False,
+        save_img_path=None,
+        sampling_per_bits=1,
+        x_BLC_lq=None
+    ): # returns List[idx_Bl]
+        if g_seed is None: rng = None
+        else: self.rng.manual_seed(g_seed); rng = self.rng
+        assert len(cfg_list) >= len(scale_schedule)
+        assert len(tau_list) >= len(scale_schedule)
+
+        # scale_schedule is used by Infinity; vae_scale_schedule is used by the VAE. If spatial
+        # patchify is applied, convert scale_schedule to vae_scale_schedule by doubling h and w.
+        if self.apply_spatial_patchify:
+            vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
+        else:
+            vae_scale_schedule = scale_schedule
+
+        kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
+        if any(np.array(cfg_list) != 1):
+            bs = 2*B
+            if not negative_label_B_or_BLT:
+                kv_compact_un = kv_compact.clone()
+                total = 0
+                for le in lens:
+                    kv_compact_un[total:total+le] = self.cfg_uncond[:le]
+                    total += le
+                kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
+                cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0)
+            else:
+                kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
+                kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
+                cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0)
+                max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
+        else:
+            bs = B
+
+        kv_compact = self.text_norm(kv_compact)
+        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) # sos shape: [2, 4096]
+        kv_compact = self.text_proj_for_ca(kv_compact) # kv_compact shape: [304, 4096]
+        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
+        last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)
+        #####
+        x_BLC_lq = x_BLC_lq.expand(bs, -1, -1)
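+        # under CFG (bs == 2*B) the low-quality control tokens are shared by the conditional
+        # and unconditional streams, hence the batch-dim expand above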
+        # x_BLC_lq = torch.cat((last_stage, self.word_embed(self.norm0_ve(x_BLC_wo_prefix_lq))), dim=1)
+        #####
+
+        with torch.amp.autocast('cuda', enabled=False):
+            cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()
+        accu_BChw, cur_L, ret = None, 0, [] # current length, list of reconstructed images
+        idx_Bl_list, idx_Bld_list = [], []
+
+        if inference_mode:
+            for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True)
+        else:
+            assert self.num_block_chunks > 1
+            for block_chunk_ in self.block_chunks:
+                for module in block_chunk_.module.module:
+                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True)
+        #####
+        if inference_mode:
+            for b in self.car_unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True)
+        else:
+            assert self.num_block_chunks > 1
+            for block_chunk_ in self.car_block_chunks:
+                for module in block_chunk_.module.module:
+                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True)
+        #####
+        abs_cfg_insertion_layers = []
+        add_cfg_on_logits, add_cfg_on_probs = False, False
+        leng = len(self.unregistered_blocks)
+        for item in cfg_insertion_layer:
+            if item == 0: # add cfg on logits
+                add_cfg_on_logits = True
+            elif item == 1: # add cfg on probs
+                add_cfg_on_probs = True # todo: in the future we may want to add cfg on both logits and probs
+            elif item < 0: # add cfg at the output of the (leng+item)-th layer
+                assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={leng}'
+                abs_cfg_insertion_layers.append(leng+item)
+            else:
+                raise ValueError(f'cfg_insertion_layer: {item} is not valid')
+
+        num_stages_minus_1 = len(scale_schedule)-1
+        summed_codes = 0
+        #####
+        patch_nums_per_level = [pn[0]*pn[1]*pn[2] for pn in scale_schedule] # NOTE: assumes pn[0] == 1, so each level contributes pn[1]*pn[2] tokens
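+        # e.g. a schedule [(1,1,1), (1,2,2), (1,4,4), ...] yields token counts [1, 4, 16, ...];
+        # torch.split below then produces one [bs, pn[1]*pn[2], C] control chunk per scale,
+        # consumed as x_BLC_lq_list[si] inside the scale loop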
+ x_BLC_lq_list = list(torch.split(x_BLC_lq,patch_nums_per_level,dim=1)) + # for si, (pn2, var_x) in enumerate(zip(patch_nums_per_level, x_BLC_list)): + # pn = int(pn2**0.5) + # var_x = var_x.transpose(1, 2).contiguous().reshape(B, CVae, pn, pn) + # var_x = self.car_var_conv(var_x) + # car_x = var_x + x_BLC_lq_list[si].transpose(1, 2).contiguous().reshape(B, CVae, pn, pn) + # car_x = car_x.view(B, CVae, -1).transpose(1, 2).contiguous() + # x_BLC_lq_list_new.append(car_x) + # x_BLC_lq = torch.cat(x_BLC_lq_list_new,dim=1) + ##### + for si, pn in enumerate(scale_schedule): # si: i-th segment + cfg = cfg_list[si] + if si >= trunk_scale: + break + cur_L += np.array(pn).prod() + + need_to_pad = 0 + attn_fn = None + if self.use_flex_attn: + # need_to_pad = (self.pad_to_multiplier - cur_L % self.pad_to_multiplier) % self.pad_to_multiplier + # if need_to_pad: + # last_stage = F.pad(last_stage, (0, 0, 0, need_to_pad)) + attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None) + + # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item' + layer_idx = 0 + ##### + control_residual_f = [] + if x_BLC_lq_list is not None: + last_stage_channel = last_stage.shape[-1] + control_x = x_BLC_lq_list[si].transpose(1, 2).contiguous().reshape(bs, last_stage_channel, pn[1], pn[2]) + control_x = self.car_control_convs(control_x) + var_x = last_stage.transpose(1, 2).contiguous().reshape(bs, last_stage_channel, pn[1], pn[2]) + var_x = self.car_var_conv(var_x) + control_x = var_x + control_x + control_x = control_x.view(bs, last_stage_channel, -1).transpose(1, 2) + # for cb in self.car_blocks: + # control_x = cb(x=control_x, cond_BD=cond_BD_or_gss, attn_bias=None) + # control_residual_f.append(control_x) + for i, chunk in enumerate(self.car_block_chunks): # this path + if self.add_lvl_embeding_only_first_block and i == 0: + control_x = self.add_lvl_embeding(control_x, si, scale_schedule, need_to_pad=need_to_pad) + if not self.add_lvl_embeding_only_first_block: + control_x = self.add_lvl_embeding(control_x, si, scale_schedule, need_to_pad=need_to_pad) + for m in chunk.module: + control_x = m(x=control_x, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si) + #control_x = chunk(x=control_x, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid) + #chunk(x=x_BLC_lq, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid) + control_residual_f.append(control_x) + + ##### + for block_idx, b in enumerate(self.block_chunks): + # last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape [5], ca_kv[2]: int + if self.add_lvl_embeding_only_first_block and block_idx == 0: + last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) + if not self.add_lvl_embeding_only_first_block: + last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad) + + ##### + if block_idx >= len(self.block_chunks) // 2: + con_f = control_residual_f.pop() + cat 
= torch.cat([last_stage, con_f], dim=-1) + cat = self.car_skip_norm[block_idx - len(self.block_chunks) // 2](cat) + last_stage = self.car_skip_linear[block_idx - len(self.block_chunks) // 2](cat) + ##### + #my code + for m in b.module: + # for m in b.module.module: #used for FSDP + last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si) + if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers): + # print(f'add cfg={cfg} on {layer_idx}-th layer output') + last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:] + last_stage = torch.cat((last_stage, last_stage), 0) + layer_idx += 1 + + if (cfg != 1) and add_cfg_on_logits: + # print(f'add cfg on add_cfg_on_logits') + logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si]) + logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:] + else: + logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si]) + + if self.use_bit_label: + tmp_bs, tmp_seq_len = logits_BlV.shape[:2] + logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2) + idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1) + else: + idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0] + if vae_type != 0: + assert returns_vemb + if si < gt_leak: + idx_Bld = gt_ls_Bl[si] + else: + assert pn[0] == 1 + idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d] + if self.apply_spatial_patchify: # unpatchify operation + idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w] + idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w] + idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d] + idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d] + + idx_Bld_list.append(idx_Bld) + codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + if si != num_stages_minus_1: + summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up) + last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_down) # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] + last_stage = last_stage.squeeze(-3) # [B, d, h, w] or [B, d, 2h, 2w] + if self.apply_spatial_patchify: # patchify operation + last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2) # [B, 4d, h, w] + last_stage = last_stage.reshape(*last_stage.shape[:2], -1) # [B, d, h*w] or [B, 4d, h*w] + last_stage = torch.permute(last_stage, [0,2,1]) # [B, h*w, d] or [B, h*w, 4d] + else: + summed_codes += codes + else: + if si < gt_leak: + idx_Bl = gt_ls_Bl[si] + h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float() # BlC + + # h_BChw = h_BChw.float().transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1]) + h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2]) + ret.append(h_BChw if returns_vemb != 0 else idx_Bl) + idx_Bl_list.append(idx_Bl) + if si != num_stages_minus_1: + accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule) + + if si 
!= num_stages_minus_1: + last_stage = self.word_embed(self.norm0_ve(last_stage)) + last_stage = last_stage.repeat(bs//B, 1, 1) + + if inference_mode: + for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) + else: + assert self.num_block_chunks > 1 + for block_chunk_ in self.block_chunks: + for module in block_chunk_.module.module: + (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) + ##### + if inference_mode: + for b in self.car_unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False) + else: + assert self.num_block_chunks > 1 + for block_chunk_ in self.car_block_chunks: + for module in block_chunk_.module.module: + (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False) + ##### + + if not ret_img: + return ret, idx_Bl_list, [] + + if vae_type != 0: + img = vae.decode(summed_codes.squeeze(-3)) + else: + img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True) + + img = (img + 1) / 2 + img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)) + return ret, idx_Bl_list, img + +def sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l) + B, l, V = logits_BlV.shape + if top_k > 0: + top_k = min(top_k, V) + idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True) + logits_BlV.masked_fill_(idx_to_remove, -torch.inf) + if top_p > 0: + sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False) + sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p) + sorted_idx_to_remove[..., -1:] = False + logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf) + # sample (have to squeeze cuz multinomial can only be used on 2D tensor) + replacement = num_samples >= 0 + num_samples = abs(num_samples) + return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples) + +def sample_with_top_k_top_p_also_modifying_logits_( + logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1 +) -> torch.Tensor: # return idx, shaped (B, l) + B, l, V = logits_BlV.shape + logits_BlV = logits_BlV.clone() + + if top_k > 0: + top_k = min(top_k, V) + idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True) + logits_BlV = logits_BlV.masked_fill(idx_to_remove, -torch.inf) + + if top_p > 0: + sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False) + sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum(dim=-1) <= (1 - top_p) + sorted_idx_to_remove[..., -1:] = False + logits_BlV = logits_BlV.masked_fill( + sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf + ) + + replacement = num_samples >= 0 + num_samples = abs(num_samples) + return torch.multinomial( + logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng + ).view(B, l, num_samples) + + +# def sample_with_top_k_top_p_also_modifying_logits_differentiable( +# logits_BlV: torch.Tensor, +# top_k: int = 0, +# top_p: float = 0.0, +# rng=None, +# num_samples=1 +# ) -> torch.Tensor: +# B, l, V = 
logits_BlV.shape
+#     probs = logits_BlV.softmax(dim=-1)  # differentiable probability distribution
+
+#     # forward pass: take the discrete indices
+#     _, indices = probs.max(dim=-1)  # (B, l)
+
+#     # straight-through trick: discrete indices in the forward pass, probability gradients in the backward pass
+#     one_hot = torch.zeros_like(probs).scatter_(-1, indices.unsqueeze(-1), 1.0)  # (B, l, V)
+#     ste_vals = probs + (one_hot - probs).detach()  # STE bridge
+
+#     # return the indices (in practice only the indices are needed; ste_vals exists solely to carry gradients)
+#     return ste_vals.argmax(dim=-1)  # (B, l)
+
+def sampling_with_top_k_top_p_also_inplace_modifying_probs_(probs_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l)
+    B, l, V = probs_BlV.shape
+    if top_k > 0:
+        top_k = min(top_k, V)
+        idx_to_remove = probs_BlV < probs_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
+        probs_BlV.masked_fill_(idx_to_remove, 0)
+    if top_p > 0:
+        sorted_probs, sorted_idx = probs_BlV.sort(dim=-1, descending=False)
+        # cumulative mass over the ascending sort; the inputs are already probabilities, so normalize rather than softmax
+        sorted_idx_to_remove = (sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)).cumsum_(dim=-1) <= (1 - top_p)
+        sorted_idx_to_remove[..., -1:] = False
+        probs_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), 0)
+    # sample (multinomial only accepts 2D tensors, hence the view)
+    probs_BlV = probs_BlV / probs_BlV.sum(-1, keepdim=True)
+    replacement = num_samples >= 0
+    num_samples = abs(num_samples)
+    return torch.multinomial(probs_BlV.view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
+
+
+def get_params_num(d, w, mlp):
+    m = round(mlp * w / 256) * 256
+    s = d * (w**2 * 8 + w*m * 2) # sa+ca, mlp
+    s += w**2 * 6 # saln
+    s += 4096 * w # pred
+    s += 32 * w # we
+
+    Ct5 = 4096
+    s += Ct5*w * 4 # T5 attn pool
+    s += Ct5*w + w*w # T5 mlp
+    return f'{s/1e9:.2f}B'
+
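+# e.g. get_params_num(32, 2048, 4) gives m = 8192 and s = 2,227,240,960, i.e. '2.23B',
+# which corresponds to the infinity_2b configuration registered below
+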
+TIMM_KEYS = {'img_size', 'pretrained', 'pretrained_cfg', 'pretrained_cfg_overlay', 'global_pool'}
+
+@register_model
+def infinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs): return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
+
+@register_model
+def cinfinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs): return CInfinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
+
+@register_model
+def finfinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs): return FInfinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
+
+@register_model
+def fainfinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs): return FAInfinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
+
+@register_model
+def ainfinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs): return AInfinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
+
+@register_model
+def binfinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs): return BInfinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
+
+@register_model
+def infinity_20b(depth=58, embed_dim=4608, num_heads=4608//128, drop_path_rate=0.25, **kwargs): return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
+
+# model configurations for scaling the Infinity transformer
+@register_model
+def infinity_layer12(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, **kwargs):
+    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
+@register_model
+def infinity_layer16(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, **kwargs):
+    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
+@register_model
+def infinity_layer24(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, **kwargs):
+    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
+@register_model
+def infinity_layer32(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, **kwargs):
+    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
+@register_model
+def infinity_layer40(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, **kwargs):
+    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
+@register_model
+def infinity_layer48(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, **kwargs):
+    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
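+
+
+if __name__ == '__main__':
+    # Minimal sanity check for the samplers above (a sketch, not part of the model API);
+    # the tensor shapes and top-k/top-p values below are arbitrary illustrative choices.
+    _logits = torch.randn(2, 3, 16)  # (B, l, V)
+    _idx = sample_with_top_k_top_p_also_inplace_modifying_logits_(_logits.clone(), top_k=5, top_p=0.9, num_samples=1)
+    assert _idx.shape == (2, 3, 1)  # one sampled index per position
+    print(get_params_num(32, 2048, 4))  # prints '2.23B' for the infinity_2b configuration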