Spaces:
Running on Zero
Running on Zero
| import math | |
| from math import pi | |
| from dataclasses import dataclass | |
| from functools import partial | |
| import torch | |
| import torch.nn as nn | |
| from torch import einsum, broadcast_tensors, Tensor | |
| import torch.nn.functional as F | |
| from torch.nn import Module | |
| from torch.amp import autocast | |
| from torch import nn | |
| from einops import rearrange, repeat, reduce | |
| import einops | |
| import numpy as np | |
| import os | |
| import copy | |
| import warnings | |
| from utils import print0 | |
| from itertools import chain | |
| # Well Inited Linear | |
| def weight_init(shape, mode, fan_in, fan_out): | |
| if mode == 'xavier_uniform': return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) | |
| if mode == 'xavier_normal': return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) | |
| if mode == 'kaiming_uniform': return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) | |
| if mode == 'kaiming_normal': return np.sqrt(1 / fan_in) * torch.randn(*shape) | |
| if mode == 'default': return np.sqrt(1 / fan_in) * (torch.rand(*shape) * 2 - 1) # nn.Linear default | |
| if mode == 'trunc_normal': return torch.nn.init.trunc_normal_(torch.empty(*shape), std=0.02) | |
| if mode == 'uniform': return torch.rand() * 2 - 1 | |
| raise ValueError(f'Invalid init mode "{mode}"') | |
| class Linear(torch.nn.Module): | |
| def __init__(self, in_features, out_features, bias=True, init_mode='trunc_normal', init_weight=1, init_bias=0): | |
| super().__init__() | |
| self.in_features = in_features | |
| self.out_features = out_features | |
| init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) | |
| self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight) | |
| self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None | |
| def forward(self, x): | |
| x = x @ self.weight.to(x.dtype).t() | |
| if self.bias is not None: | |
| x = x.add_(self.bias.to(x.dtype)) | |
| return x | |
| class AdaRMSNorm(nn.Module): | |
| def __init__(self,dim, cond_dim=None, elementwise_affine=True, cond_based_affine=True, centering=False, eps=1e-6): | |
| super(AdaRMSNorm, self).__init__() | |
| self.dim = dim | |
| self.eps = eps | |
| self.cond_based_affine = cond_based_affine | |
| self.elementwise_affine = elementwise_affine | |
| self.scale = dim ** 0.5 | |
| assert not(cond_dim is None and cond_based_affine), 'cond_dim must be provided if cond_based_affine is True' | |
| if elementwise_affine: | |
| if cond_based_affine and cond_dim is not None: | |
| self.affine = Linear(cond_dim, dim, init_weight=1e-5) | |
| else: | |
| self.weight = nn.Parameter(torch.zeros(self.dim)) | |
| else: | |
| self.register_parameter("weight", None) | |
| def forward(self, x, cond_emb=None): | |
| with torch.amp.autocast('cuda',enabled=False): | |
| output = F.normalize(x.float(), dim=(-1)) * self.scale | |
| if self.elementwise_affine: | |
| weight = self.affine(cond_emb).unsqueeze(1) if self.cond_based_affine and cond_emb is not None else self.weight | |
| output = output.mul(1. + weight.float()) | |
| return output.type_as(x) | |
| class AdaLN(nn.Module): | |
| def __init__(self,dim, cond_dim=None, elementwise_affine=True, cond_based_affine=True, bias=True, eps=1e-6): | |
| super(AdaLN, self).__init__() | |
| self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine and not cond_based_affine, eps=eps) | |
| self.cond_based_affine = cond_based_affine | |
| self.bias = bias | |
| assert not(cond_dim is None and cond_based_affine), 'cond_dim must be provided if cond_based_affine is True' | |
| self.affine = Linear(cond_dim, 2 * dim if bias else dim, init_weight=1e-5) if cond_based_affine and cond_dim is not None else None | |
| def forward(self, x, cond_emb=None): | |
| x = self.norm(x) | |
| if self.cond_based_affine: | |
| if self.bias: | |
| shift, scale = self.affine(cond_emb).unsqueeze(1).chunk(2, dim=-1) | |
| x = x.mul(1. + scale).add_(shift) | |
| else: | |
| scale = self.affine(cond_emb).unsqueeze(1) | |
| x = x.mul(1. + scale) | |
| return x | |
| class FluxRopeEMB(nn.Module): | |
| def __init__(self, theta, axes_dim: list[int]): | |
| super().__init__() | |
| # theta can be a scalar (applied to all axes) or a list (per-axis theta) | |
| if isinstance(theta, (int, float)): | |
| self.theta = [theta] * len(axes_dim) | |
| else: | |
| assert len(theta) == len(axes_dim), \ | |
| f"rope theta list length {len(theta)} must match axes_dim length {len(axes_dim)}" | |
| self.theta = list(theta) | |
| self.axes_dim = axes_dim | |
| def forward(self, ids: Tensor) -> Tensor: | |
| emb = torch.cat( | |
| [rope(ids[..., i], self.axes_dim[i], self.theta[i]) for i in range(len(self.axes_dim))], | |
| dim=-3, | |
| ) | |
| return emb.unsqueeze(1) | |
| def rope(pos: Tensor, dim: int, theta: int) -> Tensor: | |
| assert dim % 2 == 0 | |
| # Accept both (n,) and (b, n) positions. | |
| pos=pos.float() | |
| if pos.ndim == 1: | |
| pos = pos.unsqueeze(0) | |
| scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim | |
| omega = 1.0 / (theta**scale) | |
| out = torch.einsum("...n,d->...nd", pos, omega) | |
| out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) | |
| out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) | |
| return out.float() | |
| def apply_rope(x, freqs_cis): | |
| """Apply RoPE to ``x`` ``[B, H, L, D]`` using ``freqs_cis`` ``[*, 1, S, D/2, 2, 2]``.""" | |
| if freqs_cis is None: | |
| raise ValueError("freqs_cis is None but RoPE is enabled. Did you forget to pass q_pe/k_pe?") | |
| b, h, l, d = x.shape | |
| if d % 2 != 0: | |
| raise ValueError(f"RoPE requires last dim even, got D={d}") | |
| x_dtype = x.dtype | |
| x_float = x.float().reshape(b, h, l, d // 2, 2) | |
| freqs = freqs_cis.to(device=x.device, dtype=x_float.dtype) | |
| # Matrix-vector multiply using columns: y = M[...,0]*x0 + M[...,1]*x1 | |
| x0 = x_float[..., 0] | |
| x1 = x_float[..., 1] | |
| x_out = freqs[..., 0] * x0.unsqueeze(-1) + freqs[..., 1] * x1.unsqueeze(-1) # (..., 2) | |
| return x_out.reshape(b, h, l, d).to(dtype=x_dtype) | |
| # DropPath & FFN | |
| def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): # taken from timm | |
| if drop_prob == 0. or not training: return x | |
| keep_prob = 1 - drop_prob | |
| shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets | |
| random_tensor = x.new_empty(shape).bernoulli_(keep_prob) | |
| if keep_prob > 0.0 and scale_by_keep: | |
| random_tensor.div_(keep_prob) | |
| return x * random_tensor | |
| class DropPath(nn.Module): # taken from timm | |
| def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): | |
| super(DropPath, self).__init__() | |
| self.drop_prob = drop_prob | |
| self.scale_by_keep = scale_by_keep | |
| def forward(self, x): | |
| return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) | |
| def extra_repr(self): | |
| return f'(drop_prob=...)' | |
| class FFN(nn.Module): | |
| def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., ffn_type='geglu', out_value=1.0, ffn_bias=True): | |
| super().__init__() | |
| out_features = out_features or in_features | |
| hidden_features = hidden_features or in_features | |
| self.ffn_type = ffn_type | |
| if self.ffn_type == 'geglu': | |
| self.fc1 = Linear(in_features, 2*hidden_features, bias=ffn_bias) | |
| self.act = nn.GELU(approximate='tanh') | |
| self.fc2 = Linear(hidden_features, out_features, bias=ffn_bias, init_weight=out_value) | |
| elif self.ffn_type == 'ffn': | |
| self.fc1 = Linear(in_features, hidden_features, bias=ffn_bias) | |
| self.act = nn.GELU(approximate='tanh') | |
| self.fc2 = Linear(hidden_features, out_features, bias=ffn_bias, init_weight=out_value) | |
| self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity() | |
| def forward(self, x): | |
| if self.ffn_type == 'geglu': | |
| gate, value = self.fc1(x).chunk(2, dim=-1) | |
| gated = self.act(value) * gate | |
| return self.drop(self.fc2(gated)) | |
| elif self.ffn_type == 'ffn': | |
| return self.drop(self.fc2(self.act(self.fc1(x)))) | |
| class Attention(nn.Module): | |
| def __init__( | |
| self, block_idx, embed_dim=768, num_heads=12, | |
| attn_drop=0., proj_drop=0., attn_norm=False, | |
| rope=False, out_value=1.0, attn_out_bias=True | |
| ): | |
| super().__init__() | |
| assert embed_dim % num_heads == 0 | |
| self.block_idx = block_idx | |
| self.num_heads = num_heads | |
| self.head_dim = embed_dim // num_heads | |
| self.attn_norm = attn_norm | |
| self.rope = rope | |
| self.q_norm = AdaRMSNorm(self.head_dim,cond_based_affine=False) if self.attn_norm else None | |
| self.k_norm = AdaRMSNorm(self.head_dim,cond_based_affine=False) if self.attn_norm else None | |
| self.to_q = Linear(embed_dim, embed_dim, bias=False) | |
| self.to_kv = Linear(embed_dim, embed_dim * 2, bias=False) | |
| self.proj = Linear(embed_dim, embed_dim, bias=attn_out_bias, init_weight=out_value) | |
| self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity() | |
| self.attn_drop = attn_drop | |
| def forward(self, x, context_emb=None, causal=False, attn_bias=None, cache_kv=False, past_kvs=None,pe=None,ctx_pe=None, return_attn=False): | |
| B, L, C = x.shape | |
| q = self.to_q(x) | |
| q = einops.rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads) | |
| kv = self.to_kv(x) if context_emb is None else self.to_kv(context_emb) | |
| k, v = einops.rearrange(kv, 'b l (k h d) -> k b h l d', k=2, h=self.num_heads) | |
| q = self.q_norm(q) if self.q_norm is not None else q | |
| k = self.k_norm(k) if self.k_norm is not None else k | |
| if self.rope: | |
| q = apply_rope(q, pe) | |
| if context_emb is not None and ctx_pe is not None: | |
| k = apply_rope(k, ctx_pe) | |
| else: | |
| k = apply_rope(k, pe) | |
| if attn_bias is not None: | |
| attn_bias = attn_bias.unsqueeze(1).expand(B, self.num_heads, L, L) | |
| if past_kvs is not None: | |
| past_keys, past_values = past_kvs | |
| k = torch.cat((past_keys, k), dim=2) | |
| v = torch.cat((past_values, v), dim=2) | |
| if return_attn: | |
| scale = self.head_dim ** -0.5 | |
| attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale | |
| if causal and past_kvs is None and attn_bias is None: | |
| L_q, L_k = q.size(2), k.size(2) | |
| causal_mask = torch.triu(torch.ones(L_q, L_k, device=q.device, dtype=torch.bool), diagonal=1) | |
| attn_weights.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf')) | |
| if attn_bias is not None: | |
| attn_weights = attn_weights + attn_bias | |
| attn_weights = F.softmax(attn_weights, dim=-1) | |
| oup = torch.matmul(attn_weights, v).transpose(1, 2).reshape(B, L, -1) | |
| else: | |
| attn_weights = None | |
| dropout_p = self.attn_drop if self.training else 0.0 | |
| oup = F.scaled_dot_product_attention( | |
| query=q, | |
| key=k, | |
| value=v, | |
| is_causal=causal and past_kvs is None and attn_bias is None, | |
| dropout_p=dropout_p, | |
| attn_mask=attn_bias | |
| ).transpose(1, 2).reshape(B, L, -1) | |
| out = self.proj_drop(self.proj(oup)) | |
| if return_attn: | |
| return (out, (k, v), attn_weights) if cache_kv else (out, attn_weights) | |
| else: | |
| return (out, (k, v)) if cache_kv else out | |
| class TransformerLayer(nn.Module): | |
| def __init__( | |
| self, block_idx, embed_dim, cond_dim, | |
| num_heads, mlp_ratio=8/3, drop=0., attn_drop=0., drop_path=0., attn_norm=False, cross_attn=False, rope=False, | |
| ffn_type='geglu', norm_layer=AdaRMSNorm, cond_based_affine=True, ffn_bias=True, attn_out_bias=True | |
| ): | |
| super(TransformerLayer, self).__init__() | |
| self.block_idx = block_idx | |
| self.cond_based_affine = cond_based_affine | |
| self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() | |
| out_value = 1.0 if cond_based_affine else 1e-5 | |
| self.adaln_1 = norm_layer(embed_dim, cond_dim, cond_based_affine=cond_based_affine) | |
| self.attn1 = Attention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_norm=attn_norm, rope=rope, out_value=out_value, attn_out_bias=attn_out_bias) | |
| self.gate1 = Linear(cond_dim, embed_dim, bias=ffn_bias, init_weight=1e-5) if cond_based_affine else None | |
| if cross_attn: | |
| self.has_cross_attn = True | |
| self.adaln_2 = norm_layer(embed_dim, cond_dim, cond_based_affine=cond_based_affine) | |
| self.adaln_ctx = norm_layer(embed_dim, cond_dim, cond_based_affine=cond_based_affine) | |
| self.attn2 = Attention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_norm=attn_norm, rope=rope, out_value=out_value, attn_out_bias=attn_out_bias) | |
| self.gate2 = Linear(cond_dim, embed_dim, bias=ffn_bias, init_weight=1e-5) if cond_based_affine else None | |
| else: | |
| self.has_cross_attn = False | |
| self.adaln_mlp = norm_layer(embed_dim, cond_dim, cond_based_affine=cond_based_affine) | |
| mlp_ratio = eval(mlp_ratio) if isinstance(mlp_ratio, str) else mlp_ratio | |
| self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, ffn_type=ffn_type, out_value=out_value, ffn_bias=ffn_bias) | |
| self.gate_mlp = Linear(cond_dim, embed_dim, bias=ffn_bias, init_weight=1e-5) if cond_based_affine else None | |
| def forward(self, x,context_emb=None, cond_emb=None, causal=False, attn_mask=None, cache_kv=False, past_kvs=None,pe=None,ctx_pe=None, return_attn=False): | |
| self_attn_bias = attn_mask if (attn_mask is not None and not self.has_cross_attn) else None | |
| self_causal = causal if self_attn_bias is None else False | |
| attn_out1 = self.attn1(self.adaln_1(x, cond_emb), context_emb=None,causal=self_causal, attn_bias=self_attn_bias, cache_kv=cache_kv, past_kvs=past_kvs,pe=pe,ctx_pe=None, return_attn=return_attn) | |
| if return_attn: | |
| if cache_kv: | |
| attn_out1, new_kvs, layer_attn_weights = attn_out1 | |
| else: | |
| attn_out1, layer_attn_weights = attn_out1 | |
| elif cache_kv: | |
| attn_out1, new_kvs = attn_out1[0], attn_out1[1] | |
| gate1 = self.gate1(cond_emb).unsqueeze(1) if self.gate1 is not None else 1.0 | |
| x = x + gate1 * self.drop_path(attn_out1) | |
| if context_emb is not None and self.has_cross_attn: | |
| gate2 = self.gate2(cond_emb).unsqueeze(1) if self.gate2 is not None else 1.0 | |
| x = x + gate2 * self.drop_path(self.attn2(self.adaln_2(x, cond_emb), context_emb=self.adaln_ctx(context_emb,cond_emb), causal=False, attn_bias=attn_mask,pe=pe,ctx_pe=ctx_pe)) | |
| gate_mlp = self.gate_mlp(cond_emb).unsqueeze(1) if self.gate_mlp is not None else 1.0 | |
| x = x + gate_mlp * self.drop_path(self.ffn(self.adaln_mlp(x, cond_emb))) | |
| if return_attn: | |
| return (x, new_kvs, layer_attn_weights) if cache_kv else (x, layer_attn_weights) | |
| elif cache_kv: | |
| return x, new_kvs | |
| else: | |
| return x | |
| class Transformer(nn.Module): | |
| def __init__( | |
| self, | |
| layer_num, | |
| input_dim, dim, output_dim, max_seq_len, heads, | |
| cond_input_dim, cond_dim, | |
| context_type='none', ctx_input_dim=None, ctx_max_seq_len=None, | |
| emb_dropout=0., drop=0., attn_drop=0., | |
| rope=False, abs_pos=False, | |
| attn_norm=False, causal=False, | |
| zero_out=False, noise_query = False, | |
| ffn_type='geglu', # ffn or geglu | |
| norm_layer='RMSNorm', # LayerNorm or RMSNorm | |
| mlp_ratio=8/3, | |
| rope_theta=2000, | |
| rope_ctx_theta=10000, | |
| rope_axes_dim=[32,32,32], | |
| rope_input_type='l', | |
| rope_hybrid_l_len=0, | |
| rope_ctx_type='none', | |
| out_norm=True, | |
| out_layer=True, | |
| out_act=False, | |
| input_bias=True, | |
| input_layer=True, | |
| use_label=True, | |
| ffn_bias=True, | |
| out_bias=True, | |
| attn_out_bias=True, | |
| bos_zero_rope=False, | |
| **params | |
| ): | |
| super(Transformer, self).__init__() | |
| self.input_dim = input_dim | |
| self.output_dim = output_dim | |
| self.noise_query = noise_query | |
| self.cond_input_dim = cond_input_dim | |
| self.dim = dim | |
| self.causal = causal | |
| self.bos_zero_rope = bos_zero_rope | |
| if bos_zero_rope: | |
| print0("[Transformer] bos_zero_rope=True: BOS token will use zero RoPE position (pos=0)") | |
| self.concat_causal = (context_type == 'concat' and causal) | |
| if self.concat_causal: | |
| self.causal = False | |
| _ctx_len = ctx_max_seq_len | |
| _total = max_seq_len + _ctx_len | |
| _mask = torch.zeros(1, _total, _total) | |
| _mask[:, :_ctx_len, _ctx_len:] = float('-inf') | |
| _mask[:, _ctx_len:, _ctx_len:].masked_fill_( | |
| torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1), | |
| float('-inf')) | |
| self.register_buffer('concat_causal_mask', _mask) | |
| self.cross_attn = context_type=='cross_attn' | |
| print0(norm_layer) | |
| norm_layer = AdaRMSNorm if norm_layer=="RMSNorm" else AdaLN | |
| self.use_label = use_label | |
| self.layers = nn.ModuleList([ | |
| TransformerLayer(block_idx=i, embed_dim=dim, | |
| cond_dim=cond_dim, num_heads=heads, mlp_ratio=mlp_ratio, | |
| drop=drop, attn_drop=attn_drop, attn_norm=attn_norm, rope=rope, cross_attn=self.cross_attn, | |
| ffn_type=ffn_type, norm_layer=norm_layer, cond_based_affine=self.use_label, ffn_bias=ffn_bias, attn_out_bias=attn_out_bias) | |
| for i in range(layer_num) | |
| ]) | |
| out_weight = 1e-5 if zero_out else 1 | |
| scale = np.sqrt(dim) | |
| self.rope_input_type = rope_input_type | |
| self.rope_hybrid_l_len = rope_hybrid_l_len | |
| self.rope_ctx_type = rope_ctx_type | |
| self.out_norm = norm_layer(dim, cond_dim, cond_based_affine=self.use_label) if out_norm else None | |
| self.out_activate = nn.GELU(approximate='tanh') if out_act else None | |
| self.out = Linear(dim, output_dim, bias=out_bias, init_weight=out_weight) if out_layer else None | |
| if self.use_label: | |
| self.cond_proj = nn.Sequential(Linear(cond_input_dim, cond_dim), norm_layer(cond_dim, cond_based_affine=False), nn.GELU(approximate='tanh')) | |
| else: | |
| self.cond_proj = None | |
| self.max_seq_len = max_seq_len | |
| self.context_type = context_type # "concat, cross_attn, none" | |
| if abs_pos: | |
| if context_type == 'none': | |
| self.abs_pos = nn.Embedding(max_seq_len, dim) | |
| nn.init.trunc_normal_(self.abs_pos.weight, std=0.02) | |
| elif context_type == 'concat': | |
| self.abs_pos = nn.Embedding(max_seq_len + ctx_max_seq_len, dim) | |
| nn.init.trunc_normal_(self.abs_pos.weight, std=0.02) | |
| elif context_type == 'cross_attn': | |
| self.abs_pos = nn.Embedding(max_seq_len, dim) | |
| self.ctx_pos = nn.Embedding(ctx_max_seq_len, dim) | |
| nn.init.trunc_normal_(self.abs_pos.weight, std=0.02) | |
| nn.init.trunc_normal_(self.ctx_pos.weight, std=0.02) | |
| else: | |
| raise ValueError(f"Invalid context_type: {context_type}") | |
| else: | |
| self.abs_pos = None | |
| if rope: | |
| self.rope_emb = FluxRopeEMB(theta=rope_theta, axes_dim=rope_axes_dim) | |
| if context_type != 'none': | |
| self.ctx_rope_emb = FluxRopeEMB(theta=rope_ctx_theta, axes_dim=rope_axes_dim) | |
| else: | |
| self.rope_emb = None | |
| self.ctx_rope_emb = None | |
| self.input_layer = Linear(input_dim, dim, bias=input_bias) if input_layer else nn.Identity() | |
| if context_type != 'none': | |
| self.learnable_query = nn.Parameter(torch.empty(1, max_seq_len, dim)) | |
| nn.init.trunc_normal_(self.learnable_query, std=0.02) | |
| self.emb_dropout = nn.Dropout(p=emb_dropout) if emb_dropout > 0 else nn.Identity() | |
| def prepare_pos_ids(self, x_len, z_len, past_len,device): | |
| if self.context_type == 'none': | |
| if self.rope_input_type == 'hw': | |
| w = int(np.sqrt(self.max_seq_len)) | |
| idx = torch.arange(past_len, past_len + x_len, device=device) | |
| x_ids = ctx_ids = torch.stack([ | |
| (idx // w).float(), | |
| (idx % w).float(), | |
| torch.zeros(x_len, device=device, dtype=torch.float32), | |
| ], dim=1) | |
| elif self.rope_input_type == 'l': | |
| l = x_len | |
| x_coords = { | |
| "l": torch.arange(past_len, past_len + l, device=device, dtype=torch.float32), | |
| "h": torch.arange(1, device=device, dtype=torch.float32), | |
| "w": torch.arange(1, device=device, dtype=torch.float32), | |
| } | |
| x_ids = ctx_ids = torch.cartesian_prod( x_coords["l"], x_coords["h"], x_coords["w"]) | |
| elif self.rope_input_type == 'hybrid': | |
| l_len = self.rope_hybrid_l_len | |
| hw_total = self.max_seq_len - l_len | |
| h = w = int(np.sqrt(hw_total)) | |
| end_pos = past_len + x_len | |
| all_ids = [] | |
| sem_start = past_len | |
| sem_end = min(end_pos, l_len) | |
| if sem_start < sem_end: | |
| n = sem_end - sem_start | |
| all_ids.append(torch.stack([ | |
| torch.arange(sem_start, sem_end, device=device, dtype=torch.float32), | |
| torch.zeros(n, device=device, dtype=torch.float32), | |
| torch.zeros(n, device=device, dtype=torch.float32), | |
| ], dim=1)) | |
| vis_start = max(past_len, l_len) | |
| if vis_start < end_pos: | |
| n = end_pos - vis_start | |
| offset = vis_start - l_len | |
| idx = torch.arange(offset, offset + n, device=device) | |
| all_ids.append(torch.stack([ | |
| torch.zeros(n, device=device, dtype=torch.float32), | |
| (idx // w).float(), | |
| (idx % w).float(), | |
| ], dim=1)) | |
| x_ids = ctx_ids = torch.cat(all_ids, dim=0) | |
| elif self.context_type == 'concat': | |
| if self.rope_input_type == 'hw': | |
| h = w = int(np.sqrt(z_len)) | |
| x_coords = { | |
| "h": torch.arange(h, device=device, dtype=torch.float32), | |
| "w": torch.arange(w, device=device, dtype=torch.float32), | |
| "l": torch.arange(1, device=device, dtype=torch.float32), | |
| } | |
| elif self.rope_input_type == 'l': | |
| l = z_len | |
| x_coords = { | |
| "h": torch.arange(1, device=device, dtype=torch.float32), | |
| "w": torch.arange(1, device=device, dtype=torch.float32), | |
| "l": torch.arange(l, device=device, dtype=torch.float32), | |
| } | |
| if self.rope_ctx_type == 'hw': | |
| h = w = int(np.sqrt(x_len)) | |
| ctx_coords = { | |
| "h": torch.arange(h, device=device, dtype=torch.float32), | |
| "w": torch.arange(w, device=device, dtype=torch.float32), | |
| "l": torch.arange(1, device=device, dtype=torch.float32), | |
| } | |
| elif self.rope_ctx_type == 'l': | |
| l = x_len | |
| ctx_coords = { | |
| "h": torch.arange(1, device=device, dtype=torch.float32), | |
| "w": torch.arange(1, device=device, dtype=torch.float32), | |
| "l": torch.arange(l, device=device, dtype=torch.float32), | |
| } | |
| x_ids = torch.cartesian_prod(x_coords["h"], x_coords["w"], x_coords["l"]) | |
| ctx_ids = torch.cartesian_prod(ctx_coords["h"], ctx_coords["w"], ctx_coords["l"]) | |
| elif self.context_type == 'cross_attn': | |
| if self.rope_input_type == 'hw': | |
| h = w = int(np.sqrt(z_len)) | |
| x_coords = { | |
| "h": torch.arange(h, device=device, dtype=torch.float32), | |
| "w": torch.arange(w, device=device, dtype=torch.float32), | |
| "l": torch.arange(1, device=device, dtype=torch.float32), | |
| } | |
| elif self.rope_input_type == 'l': | |
| l = z_len | |
| x_coords = { | |
| "h": torch.arange(1, device=device, dtype=torch.float32), | |
| "w": torch.arange(1, device=device, dtype=torch.float32), | |
| "l": torch.arange(l, device=device, dtype=torch.float32), | |
| } | |
| if self.rope_ctx_type == 'hw': | |
| h = w = int(np.sqrt(x_len)) | |
| ctx_coords = { | |
| "h": torch.arange(h, device=device, dtype=torch.float32), | |
| "w": torch.arange(w, device=device, dtype=torch.float32), | |
| "l": torch.arange(1, device=device, dtype=torch.float32), | |
| } | |
| elif self.rope_ctx_type == 'l': | |
| l = x_len | |
| ctx_coords = { | |
| "h": torch.arange(1, device=device, dtype=torch.float32), | |
| "w": torch.arange(1, device=device, dtype=torch.float32), | |
| "l": torch.arange(l, device=device, dtype=torch.float32), | |
| } | |
| x_ids = torch.cartesian_prod(x_coords["h"], x_coords["w"], x_coords["l"]) | |
| ctx_ids = torch.cartesian_prod(ctx_coords["h"], ctx_coords["w"], ctx_coords["l"]) | |
| return x_ids, ctx_ids | |
| def forward(self, x, condition, attn_mask=None, cache_kv=False, past_kvs=None, return_attn=False): | |
| x_len = x.size(1) | |
| z_len = self.max_seq_len if self.context_type != 'none' else 0 | |
| if self.context_type == 'none': | |
| x, context_emb = self.input_layer(x), None | |
| elif self.context_type == 'concat': | |
| learnable_query = self.learnable_query.expand(x.size(0), -1, -1).to(x.device) | |
| x = self.input_layer(x) | |
| x = torch.cat((x, learnable_query), dim=1) | |
| x, context_emb = x, None | |
| else: | |
| learnable_query = self.learnable_query.expand(x.size(0), -1, -1).to(x.device) | |
| x, context_emb = learnable_query, self.input_layer(x) | |
| x = self.emb_dropout(x) | |
| cond_emb = self.cond_proj[0](condition) if (self.cond_proj is not None and condition is not None) else None | |
| cond_emb = self.emb_dropout(cond_emb) if cond_emb is not None else None | |
| cond_emb = self.cond_proj[1](cond_emb) if cond_emb is not None else None | |
| cond_emb = self.cond_proj[2](cond_emb) if cond_emb is not None else None | |
| if self.abs_pos is not None: | |
| if self.context_type == 'none': | |
| pos_ids = torch.arange(x_len, device=x.device) | |
| if past_kvs is not None: | |
| pos_ids = pos_ids + past_kvs[0][0].shape[2] | |
| x = x + self.abs_pos(pos_ids) | |
| elif self.context_type == 'concat': | |
| pos_ids = torch.arange(x_len+z_len, device=x.device) | |
| x = x + self.abs_pos(pos_ids) | |
| elif self.context_type == 'cross_attn': | |
| pos_ids = torch.arange(x_len, device=x.device) | |
| x = x + self.abs_pos(pos_ids) | |
| if context_emb is not None and self.ctx_pos is not None: | |
| context_emb = context_emb + self.ctx_pos(torch.arange(z_len, device=context_emb.device)) | |
| pe = ctx_pe = None | |
| if self.rope_emb is not None: | |
| past_len = 0 | |
| if past_kvs is not None: | |
| past_len = past_kvs[0][0].shape[2] | |
| x_ids, ctx_ids = self.prepare_pos_ids(x_len, z_len, past_len, x.device) | |
| pe = self.rope_emb(x_ids) | |
| if self.context_type != 'none': | |
| ctx_pe = self.ctx_rope_emb(ctx_ids) | |
| if self.context_type == 'concat': | |
| pe = torch.cat([ctx_pe, pe], dim=2) | |
| ctx_pe = None | |
| if self.bos_zero_rope and past_len == 0: | |
| pe = pe.clone() | |
| pe[:, :, 0, :, 0, 0] = 0 # cos = 0 | |
| pe[:, :, 0, :, 0, 1] = 0 # -sin = 0 | |
| pe[:, :, 0, :, 1, 0] = 0 # sin = 0 | |
| pe[:, :, 0, :, 1, 1] = 0 # cos = 0 | |
| attn_mask = self.concat_causal_mask if self.concat_causal else attn_mask | |
| kv_caches = [] | |
| all_attn_maps = [] if return_attn else None | |
| for i,layer in enumerate(self.layers): | |
| layer_output = layer(x, context_emb, cond_emb, self.causal, attn_mask, cache_kv, past_kvs[i] if past_kvs is not None else None, pe=pe, ctx_pe=ctx_pe, return_attn=return_attn) | |
| if return_attn: | |
| if cache_kv: | |
| x, new_kvs, layer_attn = layer_output | |
| kv_caches.append(new_kvs) | |
| else: | |
| x, layer_attn = layer_output | |
| all_attn_maps.append(layer_attn) | |
| elif cache_kv: | |
| x, new_kvs = layer_output[0], layer_output[1] | |
| kv_caches.append(new_kvs) | |
| else: | |
| x = layer_output | |
| with torch.amp.autocast('cuda', enabled=False): | |
| x = x.float() | |
| if cond_emb is not None: | |
| cond_emb = cond_emb.float() | |
| if self.out_norm is not None: | |
| x = self.out_norm(x, cond_emb) | |
| if self.out_activate is not None: | |
| x = self.out_activate(x) | |
| if self.out is not None: | |
| x = self.out(x) | |
| if return_attn: | |
| if cache_kv: | |
| return x, kv_caches, all_attn_maps | |
| else: | |
| return x, all_attn_maps | |
| elif cache_kv: | |
| return x, kv_caches | |
| else: | |
| return x | |
| def insert_eos_token(idx, z_len, eos_id): | |
| """Insert eos_id at position z_len in idx, returning a tensor with length +1.""" | |
| return torch.cat([idx[:, :z_len], | |
| torch.full((idx.shape[0], 1), eos_id, device=idx.device, dtype=idx.dtype), | |
| idx[:, z_len:]], dim=1) | |
| class VQLossDetail: | |
| quant_loss: Tensor | |
| entropy_loss: Tensor | |
| sample_entropy: Tensor | |
| batch_entropy: Tensor | |
| l2norm_z: Tensor | |
| l2norm_code: Tensor | |
| def zero(device): | |
| z = torch.zeros((), device=device) | |
| return VQLossDetail(z, z, z, z, z, z) | |
| def from_tuple(t): | |
| return VQLossDetail(*t) | |
| class EncoderOutput: | |
| quant: Tensor | |
| indices: Tensor | |
| one_hot: Tensor | |
| semantic_vq_loss: VQLossDetail | |
| visual_vq_loss: VQLossDetail | |
| semantic_quant: Tensor | |
| visual_quant: Tensor | |
| semantic_indices: Tensor | |
| visual_indices: Tensor | |
| semantic_one_hot: Tensor = None | |
| class AROutput: | |
| logits: Tensor | |
| semantic_logits: Tensor | |
| visual_logits: Tensor | |
| class Encoder(nn.Module): | |
| def __init__(self, config) -> None: | |
| super().__init__() | |
| self.prologue = config.get("Prologue", False) | |
| self.share_semantic_encoder = config.get("share_semantic_encoder", True) | |
| self.share_semantic_codebook = config.get("share_semantic_codebook", False) | |
| self.z_len = config["z_len"] | |
| self.x_len = config["x_len"] | |
| if self.prologue: | |
| if self.share_semantic_encoder and config["Encoder"].get("context_type") != "concat": | |
| warnings.warn(f"Prologue with shared backbone requires Encoder context_type='concat', got '{config['Encoder'].get('context_type')}'. Forcing to 'concat'.") | |
| config["Encoder"]["context_type"] = "concat" | |
| self.enc = Transformer(**config["Encoder"]) | |
| self.quantizer = PrologueQuantizer(**config["Quantizer"]) | |
| if self.prologue and not self.share_semantic_codebook: | |
| self.semantic_quantizer = PrologueQuantizer(**config["SemanticQuantizer"]) | |
| if self.prologue and not self.share_semantic_encoder: | |
| self.semantic_input_type = config.get("semantic_input_type", "encoder_output") | |
| self.semantic_enc = Transformer(**config["SemanticEncoder"]) | |
| if not self.prologue or self.share_semantic_codebook: | |
| self._forward_impl = self._forward_simple | |
| self._encode_idx_impl = self._encode_idx_simple | |
| elif self.share_semantic_encoder: | |
| self._forward_impl = self._forward_split_codebook | |
| self._encode_idx_impl = self._encode_idx_split_codebook | |
| else: | |
| self._forward_impl = self._forward_separate_enc | |
| self._encode_idx_impl = self._encode_idx_separate_enc | |
| def forward(self, x, labels, training=False) -> EncoderOutput: | |
| return self._forward_impl(x, labels, training) | |
| def encode_idx(self, x: torch.Tensor, labels=None) -> torch.Tensor: | |
| return self._encode_idx_impl(x, labels) | |
| # ---- strategy: simple (no Prologue, or Prologue with shared codebook) ---- | |
| def _forward_simple(self, x, labels, training): | |
| h = self.enc(x, labels) | |
| if self.enc.context_type == 'concat': | |
| h = h[:, -self.enc.max_seq_len:] | |
| quant, idx, one_hot, vqloss = self.quantizer(h, labels, training=training) | |
| vl = VQLossDetail.from_tuple(vqloss) if isinstance(vqloss, tuple) else VQLossDetail(vqloss, *([torch.zeros((), device=quant.device)] * 5)) | |
| return EncoderOutput( | |
| quant=quant, indices=idx, one_hot=one_hot, | |
| semantic_vq_loss=None, visual_vq_loss=vl, | |
| semantic_quant=None, visual_quant=quant, | |
| semantic_indices=None, visual_indices=idx, | |
| semantic_one_hot=None, | |
| ) | |
| def _encode_idx_simple(self, x, labels): | |
| z = self.enc(x, labels) | |
| if self.enc.context_type == 'concat': | |
| z = z[:, -self.enc.max_seq_len:] | |
| return self.quantizer.encode(z, labels) | |
| # ---- strategy: split codebook (Prologue, shared encoder, separate codebooks) ---- | |
| def _forward_split_codebook(self, x, labels, training): | |
| h = self.enc(x, labels) | |
| h_v, h_s = h[:, :self.x_len, :], h[:, self.x_len:, :] | |
| quant_v, idx_v, oh_v, loss_v = self.quantizer(h_v, labels, training=training) | |
| quant_s, idx_s, oh_s, loss_s = self.semantic_quantizer(h_s, labels, training=training) | |
| raw_oh_s = oh_s | |
| idx = torch.cat([idx_s, idx_v], dim=1) | |
| if oh_s is not None and oh_v is not None: | |
| max_cb = max(oh_s.shape[-1], oh_v.shape[-1]) | |
| if oh_s.shape[-1] < max_cb: | |
| oh_s = F.pad(oh_s, (0, max_cb - oh_s.shape[-1])) | |
| if oh_v.shape[-1] < max_cb: | |
| oh_v = F.pad(oh_v, (0, max_cb - oh_v.shape[-1])) | |
| one_hot = torch.cat([oh_s, oh_v], dim=1) | |
| else: | |
| one_hot = None | |
| vl_s = VQLossDetail.from_tuple(loss_s) if isinstance(loss_s, tuple) else VQLossDetail(loss_s, *([torch.zeros((), device=quant_v.device)] * 5)) | |
| vl_v = VQLossDetail.from_tuple(loss_v) if isinstance(loss_v, tuple) else VQLossDetail(loss_v, *([torch.zeros((), device=quant_v.device)] * 5)) | |
| return EncoderOutput( | |
| quant=quant_v, indices=idx, one_hot=one_hot, | |
| semantic_vq_loss=vl_s, visual_vq_loss=vl_v, | |
| semantic_quant=quant_s, visual_quant=quant_v, | |
| semantic_indices=idx_s, visual_indices=idx_v, | |
| semantic_one_hot=raw_oh_s, | |
| ) | |
| def _encode_idx_split_codebook(self, x, labels): | |
| z = self.enc(x, labels) | |
| z_v, z_s = z[:, :self.x_len, :], z[:, self.x_len:, :] | |
| idx_v = self.quantizer.encode(z_v, labels) | |
| idx_s = self.semantic_quantizer.encode(z_s, labels) | |
| return torch.cat([idx_s, idx_v], dim=1) | |
| # ---- strategy: separate encoder (Prologue-Post, independent backbone) ---- | |
| def _forward_separate_enc(self, x, labels, training): | |
| h = self.enc(x, labels) | |
| vq_out = self.quantizer(h, labels, training=training, | |
| return_continuous=(self.semantic_input_type == "pre_quant")) | |
| if self.semantic_input_type == "pre_quant": | |
| quant_v, idx_v, oh_v, loss_v, z_continuous = vq_out | |
| sem_input = z_continuous | |
| else: | |
| quant_v, idx_v, oh_v, loss_v = vq_out | |
| sem_input = h if self.semantic_input_type == "encoder_output" else x | |
| h_s = self.semantic_enc(sem_input, labels) | |
| if self.semantic_enc.context_type == 'concat': | |
| h_s = h_s[:, -self.semantic_enc.max_seq_len:] | |
| quant_s, idx_s, oh_s, loss_s = self.semantic_quantizer(h_s, labels, training=training) | |
| raw_oh_s = oh_s | |
| idx = torch.cat([idx_s, idx_v], dim=1) | |
| if oh_s is not None and oh_v is not None: | |
| max_cb = max(oh_s.shape[-1], oh_v.shape[-1]) | |
| if oh_s.shape[-1] < max_cb: | |
| oh_s = F.pad(oh_s, (0, max_cb - oh_s.shape[-1])) | |
| if oh_v.shape[-1] < max_cb: | |
| oh_v = F.pad(oh_v, (0, max_cb - oh_v.shape[-1])) | |
| one_hot = torch.cat([oh_s, oh_v], dim=1) | |
| else: | |
| one_hot = None | |
| vl_s = VQLossDetail.from_tuple(loss_s) if isinstance(loss_s, tuple) else VQLossDetail(loss_s, *([torch.zeros((), device=quant_v.device)] * 5)) | |
| vl_v = VQLossDetail.from_tuple(loss_v) if isinstance(loss_v, tuple) else VQLossDetail(loss_v, *([torch.zeros((), device=quant_v.device)] * 5)) | |
| return EncoderOutput( | |
| quant=quant_v, indices=idx, one_hot=one_hot, | |
| semantic_vq_loss=vl_s, visual_vq_loss=vl_v, | |
| semantic_quant=quant_s, visual_quant=quant_v, | |
| semantic_indices=idx_s, visual_indices=idx_v, | |
| semantic_one_hot=raw_oh_s, | |
| ) | |
| def _encode_idx_separate_enc(self, x, labels): | |
| h = self.enc(x, labels) | |
| idx_v = self.quantizer.encode(h, labels) | |
| if self.semantic_input_type == "encoder_output": | |
| sem_input = h | |
| elif self.semantic_input_type == "image_patch": | |
| sem_input = x | |
| else: | |
| vq_out = self.quantizer(h, labels, training=False, return_continuous=True) | |
| sem_input = vq_out[4] | |
| h_s = self.semantic_enc(sem_input, labels) | |
| if self.semantic_enc.context_type == 'concat': | |
| h_s = h_s[:, -self.semantic_enc.max_seq_len:] | |
| idx_s = self.semantic_quantizer.encode(h_s, labels) | |
| return torch.cat([idx_s, idx_v], dim=1) | |
| # ---- helper properties & methods ---- | |
| def has_separate_semantic(self): | |
| return self.prologue and not self.share_semantic_encoder | |
| def visual_modules(self): | |
| return ["enc", "quantizer"] | |
| def semantic_modules(self): | |
| return ["semantic_enc", "semantic_quantizer"] if self.has_separate_semantic else [] | |
| def total_token_len(self): | |
| if self.prologue and not self.share_semantic_codebook: | |
| return self.z_len + self.x_len | |
| return self.z_len | |
| def get_visual_codes(self, indices, labels): | |
| if self.prologue and not self.share_semantic_codebook: | |
| visual_ids = indices[:, -self.x_len:] | |
| return self.quantizer.get_codes_w_indices(visual_ids, labels) | |
| return self.quantizer.get_codes_w_indices(indices, labels) | |
| def visual_parameters(self): | |
| return chain(self.enc.parameters(), self.quantizer.parameters()) | |
| def semantic_parameters(self): | |
| return chain(self.semantic_enc.parameters(), self.semantic_quantizer.parameters()) | |
| class Decoder(nn.Module): | |
| def __init__(self, config) -> None: | |
| super().__init__() | |
| self.dec = Transformer(**config["Decoder"]) | |
| self.is_concat = config["Decoder"].get("context_type", "none") == "concat" | |
| self.output_len = config["Decoder"]["max_seq_len"] | |
| def forward(self, quant, labels): | |
| x_hat = self.dec(quant, labels) | |
| if self.is_concat: | |
| x_hat = x_hat[:, -self.output_len:] | |
| return x_hat | |
| class ARModel(nn.Module): | |
| def __init__(self, config) -> None: | |
| super().__init__() | |
| ar_config = config["ARModel"] | |
| self.conditional_injection = ar_config.get("conditional_injection", "llamagen") # dit | llamagen | |
| # ste_ar_embedding: use one_hot @ W_emb for the prologue prefix so grads flow into the encoder. | |
| self.ste_ar_embedding = ar_config.get("ste_ar_embedding", False) | |
| self.num_classes = ar_config.cond_input_dim | |
| print0("conditional_injection: ", self.conditional_injection) | |
| print0("ste_ar_embedding: ", self.ste_ar_embedding) | |
| print0("num_classes: ", self.num_classes) | |
| prologue = config.get("Prologue", False) and not config.get("share_semantic_codebook", False) | |
| self.z_len = int(config.get("z_len", 0)) if prologue else 0 | |
| vis_cb_size = int(config["Quantizer"]["codebook_size"]) | |
| sem_cb_size = int(config["SemanticQuantizer"]["codebook_size"]) if prologue else 0 | |
| self.ar_vocab_size = vis_cb_size + sem_cb_size | |
| self.semantic_offset = vis_cb_size | |
| self.semantic_codebook_size = sem_cb_size | |
| self.visual_codebook_size = vis_cb_size | |
| use_eos = bool(config.get("use_eos", False)) | |
| self.eos_len = 1 if (use_eos and self.z_len > 0) else 0 | |
| if self.eos_len > 0: | |
| self.eos_token_id = self.ar_vocab_size | |
| self.ar_vocab_size += 1 | |
| ar_config['max_seq_len'] = int(ar_config['max_seq_len']) + 1 | |
| if 'rope_hybrid_l_len' in ar_config and int(ar_config['rope_hybrid_l_len']) > 0: | |
| ar_config['rope_hybrid_l_len'] = int(ar_config['rope_hybrid_l_len']) + 1 | |
| print0(f"use_eos: eos_token_id={self.eos_token_id}, max_seq_len→{ar_config['max_seq_len']}") | |
| else: | |
| self.eos_token_id = -1 | |
| print0(f"AR vocab: {self.ar_vocab_size}, semantic_offset: {self.semantic_offset}") | |
| bos_num = self.num_classes if self.conditional_injection == "llamagen" else 1 | |
| self.bos_emb = nn.Embedding(bos_num, ar_config["dim"]) | |
| nn.init.trunc_normal_(self.bos_emb.weight, std=0.02) | |
| self.semantic_emb = nn.Embedding(self.ar_vocab_size, ar_config["dim"]) | |
| nn.init.trunc_normal_(self.semantic_emb.weight, std=0.02) | |
| self.tied_embedding = ar_config.get("tied_embedding", False) | |
| if self.tied_embedding: | |
| ar_config['out_layer'] = False | |
| print0("tied_embedding: True, output layer disabled in Transformer") | |
| ar_config['output_dim'] = self.ar_vocab_size | |
| self.ar_model = Transformer(**ar_config) | |
| self.temperature = ar_config["temperature"] | |
| self.max_length = ar_config["max_seq_len"] | |
| uncond_idx = self.num_classes - 1 | |
| self.register_buffer( | |
| 'uncond_ar_labels', | |
| F.one_hot(torch.tensor([uncond_idx], dtype=torch.long), num_classes=self.num_classes).float() | |
| ) | |
| self.register_buffer('logit_mask', None, persistent=False) | |
| def forward(self, idx, labels, temperature=None, semantic_one_hot=None, return_attn=False) -> AROutput: | |
| bz = idx.shape[0] | |
| if self.conditional_injection == "llamagen": | |
| bos = self.bos_emb(torch.argmax(labels, dim=1)).unsqueeze(1) | |
| labels = self.uncond_ar_labels.expand(bz, -1).to(device=labels.device, dtype=labels.dtype) | |
| else: | |
| bos = self.bos_emb(torch.zeros(bz, device=idx.device, dtype=torch.long)).unsqueeze(1) | |
| # STE: prefix embedding via one_hot @ W_emb so grads reach the encoder. | |
| if self.ste_ar_embedding and semantic_one_hot is not None and self.z_len > 0: | |
| sem_weight = self.semantic_emb.weight[self.semantic_offset:self.semantic_offset + self.semantic_codebook_size] | |
| sem_token_emb = semantic_one_hot @ sem_weight | |
| rest_token_emb = self.semantic_emb(idx[:, self.z_len:-1]) | |
| token_emb = torch.cat([sem_token_emb, rest_token_emb], dim=1) | |
| elif self.ste_ar_embedding and semantic_one_hot is not None: | |
| weight = self.semantic_emb.weight[:self.visual_codebook_size] | |
| token_emb = semantic_one_hot[:, :-1] @ weight | |
| else: | |
| token_emb = self.semantic_emb(idx[:, :-1]) | |
| shift_input = torch.cat([bos, token_emb], dim=1) | |
| result = self.ar_model(shift_input, labels, return_attn=return_attn) | |
| if return_attn: | |
| out, all_attn_maps = result | |
| else: | |
| out = result | |
| logits = F.linear(out, self.semantic_emb.weight) if self.tied_embedding else out | |
| if temperature is not None: | |
| logits = logits / temperature | |
| elif self.temperature is not None: | |
| logits = logits / self.temperature | |
| if self.z_len > 0: | |
| semantic_logits = logits[:, :self.z_len] | |
| visual_logits = logits[:, self.z_len + self.eos_len:] | |
| else: | |
| semantic_logits = logits | |
| visual_logits = logits | |
| ar_output = AROutput(logits=logits, semantic_logits=semantic_logits, visual_logits=visual_logits) | |
| if return_attn: | |
| return ar_output, all_attn_maps | |
| return ar_output | |
| def set_logit_mask(self, mask): | |
| """Install per-step ``logit_mask`` buffer; pads EOS row/slot at ``z_len`` when ``use_eos``.""" | |
| if self.eos_len > 0: | |
| if mask is not None: | |
| mask = F.pad(mask, (0, 1), value=float('-inf')) | |
| eos_row = torch.full((1, self.ar_vocab_size), float('-inf')) | |
| eos_row[0, self.eos_token_id] = 0. | |
| mask = torch.cat([mask[:self.z_len], eos_row, mask[self.z_len:]], dim=0) | |
| else: | |
| mask = torch.zeros(self.max_length, self.ar_vocab_size) | |
| mask[:, self.eos_token_id] = float('-inf') | |
| mask[self.z_len] = float('-inf') | |
| mask[self.z_len, self.eos_token_id] = 0. | |
| self.register_buffer('logit_mask', mask, persistent=False) | |
| def sampling(self, bz, class_label=None, temperature=1.0, topK=None, topP=None, cfg=16.0, cfg_schedule='cosine', cfg_power=2.75, cache_kv=False, | |
| semantic_cfg_schedule=None, semantic_cfg_scale=None, semantic_cfg_power=None, semantic_cfg_start=0.0, | |
| visual_cfg_schedule=None, visual_cfg_scale=None, visual_cfg_power=None, visual_cfg_start=1.0, | |
| semantic_temperature=None): | |
| cfg = 0. if class_label is None else cfg | |
| use_segmented = (semantic_cfg_schedule is not None or visual_cfg_schedule is not None) | |
| use_cfg = cfg > 0. | |
| if use_segmented: | |
| _ss = semantic_cfg_scale if semantic_cfg_scale is not None else cfg | |
| _vs = visual_cfg_scale if visual_cfg_scale is not None else cfg | |
| use_cfg = _ss > 0. or _vs > 0. | |
| uncond_idx = int(self.ar_model.cond_input_dim) - 1 | |
| device = self.bos_emb.weight.device | |
| if self.conditional_injection == "llamagen": | |
| cond_bos = self.bos_emb(torch.argmax(class_label, dim=1)).unsqueeze(1) # [B, 1, D] | |
| uncond_bos = self.bos_emb(torch.full((bz,), uncond_idx, device=device, dtype=torch.long)).unsqueeze(1) | |
| ar_labels = self.uncond_ar_labels.expand(bz, -1).to(device=device) # [B, num_classes] | |
| uncond_labels = self.uncond_ar_labels.expand(bz, -1).to(device=device) | |
| else: | |
| cond_bos = self.bos_emb(torch.zeros(bz, device=device, dtype=torch.long)).unsqueeze(1) | |
| uncond_bos = cond_bos | |
| ar_labels = class_label | |
| uncond_labels = self.uncond_ar_labels.expand(bz, -1).to(device=device) | |
| quant_input = torch.cat([cond_bos, uncond_bos], dim=0) if use_cfg else cond_bos | |
| ar_labels = torch.cat([ar_labels, uncond_labels], dim=0) if use_cfg else ar_labels | |
| quant_output = [] | |
| past_kvs = None | |
| for step in range(self.max_length): | |
| # CFG | |
| if use_cfg: | |
| ar_out = self.ar_model(quant_input, ar_labels, cache_kv=cache_kv, past_kvs=past_kvs) | |
| if cache_kv: | |
| hidden_all, past_kvs = ar_out | |
| else: | |
| hidden_all = ar_out | |
| if self.tied_embedding: | |
| hidden_all = F.linear(hidden_all[:, -1:], self.semantic_emb.weight) | |
| logits_all = hidden_all[:, -1] | |
| logits, uncond_logits = logits_all.chunk(2, dim=0) | |
| is_semantic_step = (self.z_len > 0 and step < self.z_len) | |
| if use_segmented: | |
| if is_semantic_step: | |
| seg_schedule = semantic_cfg_schedule or 'constant' | |
| seg_scale = semantic_cfg_scale if semantic_cfg_scale is not None else cfg | |
| seg_power = semantic_cfg_power if semantic_cfg_power is not None else cfg_power | |
| seg_start = semantic_cfg_start | |
| seg_t = step / self.z_len if self.z_len > 0 else 0.0 | |
| else: | |
| seg_schedule = visual_cfg_schedule or 'constant' | |
| seg_scale = visual_cfg_scale if visual_cfg_scale is not None else cfg | |
| seg_power = visual_cfg_power if visual_cfg_power is not None else cfg_power | |
| seg_start = visual_cfg_start | |
| visual_start = self.z_len | |
| visual_len = self.max_length - visual_start | |
| seg_t = (step - visual_start) / visual_len if visual_len > 0 else 0.0 | |
| if seg_schedule == 'constant': | |
| cfg_scale = seg_scale | |
| elif seg_schedule == 'linear': | |
| cfg_scale = seg_start + (seg_scale - seg_start) * seg_t | |
| elif seg_schedule == 'cosine': | |
| shape = (1 - math.cos( | |
| (seg_t ** seg_power) * math.pi)) * 0.5 | |
| cfg_scale = seg_start + (seg_scale - seg_start) * shape | |
| else: | |
| raise ValueError(f"Invalid segmented cfg schedule: {seg_schedule}") | |
| elif cfg_schedule == 'constant': | |
| cfg_scale = cfg | |
| elif cfg_schedule == 'linear': | |
| cfg_scale = 1.0 * (1-step/self.max_length) + cfg * (step/self.max_length) | |
| elif cfg_schedule == 'cosine': | |
| cfg_scale = (1 - math.cos( | |
| ((step / self.max_length) ** cfg_power) * math.pi)) * 1/2 | |
| cfg_scale = (cfg - 1) * cfg_scale + 1 | |
| else: | |
| raise ValueError(f"Invalid cfg_schedule: {cfg_schedule}") | |
| logits = cfg_scale * logits + (1 - cfg_scale) * uncond_logits | |
| else: | |
| ar_out = self.ar_model(quant_input, ar_labels, cache_kv=cache_kv, past_kvs=past_kvs) | |
| if cache_kv: | |
| hidden, past_kvs = ar_out | |
| else: | |
| hidden = ar_out | |
| if self.tied_embedding: | |
| hidden = F.linear(hidden[:, -1:], self.semantic_emb.weight) | |
| logits = hidden[:, -1] | |
| is_semantic_step_temp = (self.z_len > 0 and step < self.z_len) | |
| t = semantic_temperature if (semantic_temperature is not None and is_semantic_step_temp) else temperature | |
| logits = logits / t | |
| if self.logit_mask is not None: | |
| logits = logits + self.logit_mask[step] | |
| # Top-K filtering | |
| if topK is not None and topK > 0.: | |
| top_logits, top_indices = logits.topk(topK, dim=-1) | |
| logits = torch.full_like(logits, float('-inf')) | |
| logits.scatter_(dim=-1, index=top_indices, src=top_logits) | |
| # Top-P (nucleus) filtering | |
| if topP is not None and 0. < topP < 1.: | |
| sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True) | |
| probs_sum = sorted_logits.softmax(dim=-1).cumsum(dim=-1) | |
| mask = probs_sum > topP | |
| mask[..., 1:] = mask[..., :-1].clone() | |
| mask[..., 0] = False | |
| sorted_logits[mask] = float('-inf') | |
| logits = torch.full_like(logits, float('-inf')) | |
| logits.scatter_(dim=-1, index=sorted_indices, src=sorted_logits) | |
| # Sample the next token | |
| with torch.amp.autocast("cuda", enabled=False): | |
| next_idx = torch.multinomial(F.softmax(logits.float(), dim=-1), 1) | |
| next_idx = next_idx.to(dtype=torch.long) | |
| quant_output.append(next_idx) | |
| # Feed the new token back as the next input embedding | |
| next_emb = self.semantic_emb(next_idx) | |
| if use_cfg: | |
| next_emb = torch.cat([next_emb, next_emb], dim=0) # [2B, 1, D] | |
| if not cache_kv: | |
| quant_input = torch.cat((quant_input, next_emb), dim=1) | |
| else: | |
| quant_input = next_emb | |
| quant_output = torch.cat(quant_output, dim=1) | |
| return quant_output | |
| # deterministic Hard-Gumbel Softmax Quantizer | |
| class L2Normalize(nn.Module): | |
| """Wrapper for F.normalize to make it a proper nn.Module""" | |
| def __init__(self, dim=-1): | |
| super().__init__() | |
| self.dim = dim | |
| def forward(self, x, element_wise=False): | |
| if not element_wise: | |
| return F.normalize(x, dim=self.dim) | |
| else: | |
| dtype_one = torch.Tensor([1.0]).to(x.device).to(x.dtype) | |
| return torch.where(x>0, dtype_one, -dtype_one) | |
| class PrologueQuantizer(nn.Module): | |
| """VQ with optional STE variants (``VQ`` / ``LFQ`` / ``ProbQ``) and per-position ``pos_select_mask``.""" | |
| def __init__(self, codebook_size, dim, z_dim, cond_dim, | |
| codebook_init='trunc_normal', | |
| z_norm_type='none', # AdaLN, AdaRMSNorm, FixLN, FixRMSNorm, l2, none | |
| temperature=1.0, | |
| frozen_codebook=False, codebook_proj='none', | |
| ste='ProbQ', | |
| length=256, | |
| vq_beta=0.25, use_quant_loss=False, use_entropy_loss=False, monitor_metrics=True, | |
| **params): | |
| super().__init__() | |
| self.monitor_metrics = monitor_metrics | |
| self.codebook_size = codebook_size # codebook size (K) | |
| self.dim = dim # embedding dimension (D) | |
| self.z_dim = z_dim | |
| self.length = int(length) | |
| self.proj_in = Linear(dim, z_dim) | |
| elementwise_affine = True | |
| cond_based_affine = False | |
| self.cond_proj = None | |
| if 'Ada' in z_norm_type: | |
| self.cond_proj = nn.Sequential(Linear(1001, cond_dim), AdaRMSNorm(cond_dim, cond_based_affine=False), nn.GELU(approximate='tanh')) | |
| elementwise_affine = True | |
| cond_based_affine = True | |
| if 'Fix' in z_norm_type: | |
| cond_based_affine = False | |
| elementwise_affine = False | |
| self.z_norm_type = z_norm_type | |
| self.z_norm = nn.Identity() | |
| if 'l2' in z_norm_type: | |
| self.z_norm = L2Normalize(dim=-1) | |
| elif 'RMSNorm' in z_norm_type: | |
| self.z_norm = AdaRMSNorm(z_dim, cond_dim, elementwise_affine=elementwise_affine, cond_based_affine=cond_based_affine) | |
| elif 'LN' in z_norm_type: | |
| self.z_norm = AdaLN(z_dim, cond_dim, elementwise_affine=elementwise_affine, cond_based_affine=cond_based_affine) | |
| if 'LFQ' in z_norm_type: | |
| self.z_norm = L2Normalize(dim=-1, element_wise=True) | |
| self.codebook_init = codebook_init | |
| if self.codebook_init == 'LFQ': | |
| self.register_buffer("indices_map", 2**torch.arange(z_dim).view(1, 1, -1)) # B L D | |
| self.register_buffer("embedding", self.indices_to_emb(torch.arange(codebook_size).view(-1, 1, 1)).view(-1, z_dim)) # N D | |
| else: | |
| self.embedding = nn.Embedding(codebook_size, z_dim) | |
| if codebook_init == 'trunc_normal': | |
| torch.nn.init.trunc_normal_(self.embedding.weight, std=0.02) | |
| elif codebook_init == 'uniform': | |
| self.embedding.weight.data.uniform_(-np.sqrt(1 / codebook_size), np.sqrt(1 / codebook_size)) | |
| elif codebook_init == 'simvq': | |
| self.embedding.weight.data.normal_(mean=0, std=self.z_dim**-0.5) | |
| if frozen_codebook: | |
| self.embedding.requires_grad_(False) | |
| self.codebook_proj = nn.Identity() | |
| if codebook_proj == 'linear': | |
| self.codebook_proj = Linear(z_dim, z_dim) | |
| elif codebook_proj == 'norm_linear': | |
| self.codebook_proj = nn.Sequential(AdaRMSNorm(z_dim, elementwise_affine=True, cond_based_affine=False), Linear(z_dim, z_dim)) | |
| elif codebook_proj == 'linear_norm': | |
| self.codebook_proj = nn.Sequential(Linear(z_dim, z_dim), AdaRMSNorm(z_dim, elementwise_affine=True, cond_based_affine=False)) | |
| elif codebook_proj == 'mlp': | |
| self.codebook_proj = nn.Sequential(Linear(z_dim, dim), nn.GELU(approximate='tanh'), Linear(dim, z_dim)) | |
| elif codebook_proj == 'mlp_with_norm': | |
| self.codebook_proj = nn.Sequential(Linear(z_dim, dim), AdaRMSNorm(dim, elementwise_affine=True, cond_based_affine=False), nn.GELU(approximate='tanh'), Linear(dim, z_dim)) | |
| elif codebook_proj == 'l2': | |
| self.codebook_proj = L2Normalize(dim=-1) | |
| self.ste = ste | |
| self.vq_beta = vq_beta | |
| self.use_quant_loss = use_quant_loss | |
| self.use_entropy_loss = use_entropy_loss | |
| self.temperature = temperature | |
| print0(f"z_norm_type: {z_norm_type}") | |
| print0(f"codebook_proj: {codebook_proj}") | |
| print0(f"codebook_init: {codebook_init}") | |
| print0(f"frozen_codebook: {frozen_codebook}") | |
| print0(f"temperature: {temperature}") | |
| print0(f"ste: {ste}") | |
| # All-zero pos_select_mask; AR sampling composes a global mask (see utils.build_ar_logit_mask). | |
| mask_bool = torch.ones(self.length, codebook_size, dtype=torch.bool) | |
| self.register_buffer('pos_select_mask', torch.where(mask_bool, 0., float('-inf'))) | |
| def indices_to_emb(self, indices): | |
| return ((indices.int() & self.indices_map) != 0).float() * 2. - 1. | |
| def codebook(self): | |
| if self.codebook_init == 'LFQ': | |
| return self.embedding | |
| else: | |
| return self.embedding.weight | |
| def _get_pos_mask(self, L): | |
| mask = self.pos_select_mask | |
| return mask[:L] if L < mask.shape[0] else mask | |
| def forward(self, x, labels=None, training=False, log_usage=False, indices=None, return_continuous=False): | |
| B, L, D = x.shape | |
| x = x.float() | |
| labels = self.cond_proj(labels.float()) if self.cond_proj is not None and labels is not None else labels | |
| z = self.proj_in(x) | |
| z_normed = self.z_norm(z, labels) if 'Ada' in self.z_norm_type else self.z_norm(z) | |
| codebook = self.codebook | |
| codebook_normed = self.codebook_proj(codebook) | |
| logits = torch.einsum('bld,nd->bln', z_normed, codebook_normed) | |
| prob = F.softmax(logits / self.temperature, dim=-1) | |
| pos_mask = self._get_pos_mask(L) | |
| indices = torch.argmax(prob + pos_mask, dim=-1) if indices is None else indices | |
| one_hot_ng = F.one_hot(indices, self.codebook_size).view(B, L, -1).to(z.device).to(z.dtype) | |
| one_hot = prob + (one_hot_ng - prob).detach() | |
| if self.ste == 'VQ': | |
| quant = codebook_normed[indices] | |
| quant_ste = z_normed + (quant - z_normed).detach() | |
| elif self.ste == 'LFQ': | |
| quant = codebook[indices] | |
| quant_ste = z + (quant - z).detach() | |
| elif self.ste == 'ProbQ': | |
| quant_ste = torch.einsum('bln,nd->bld', one_hot, codebook_normed) | |
| quant_loss = torch.tensor(0., device=z.device, dtype=z.dtype) | |
| if self.monitor_metrics or (training and self.use_quant_loss): | |
| if self.ste == 'VQ': | |
| quant_loss = self.vq_beta * (quant.detach() - z_normed).pow(2).mean() + (quant - z_normed.detach()).pow(2).mean() | |
| elif self.ste == 'LFQ': | |
| quant_loss = self.vq_beta * torch.mean((quant.detach() - z)**2) + torch.mean((quant - z.detach())**2) | |
| elif self.ste == 'ProbQ': | |
| quant_ng = torch.einsum('bln,nd->bld', one_hot_ng, codebook_normed) | |
| quant_loss = torch.mean((quant_ste - z_normed)**2) + self.vq_beta * torch.mean((quant_ng.detach() - z_normed)**2) + torch.mean((quant_ng - z_normed.detach())**2) | |
| quant_loss = quant_loss.detach() if not (training and self.use_quant_loss) else quant_loss | |
| sample_entropy = torch.tensor(0., device=z.device, dtype=z.dtype) | |
| batch_entropy = torch.tensor(0., device=z.device, dtype=z.dtype) | |
| entropy_loss = torch.tensor(0., device=z.device, dtype=z.dtype) | |
| if self.monitor_metrics or (training and self.use_entropy_loss): | |
| masked_logits = logits + pos_mask | |
| sample_entropy, batch_entropy, entropy_loss = compute_entropy_loss(logits=masked_logits.reshape(-1, self.codebook_size)) | |
| sample_entropy = sample_entropy.detach() if not (training and self.use_entropy_loss) else sample_entropy | |
| batch_entropy = batch_entropy.detach() if not (training and self.use_entropy_loss) else batch_entropy | |
| entropy_loss = entropy_loss.detach() if not (training and self.use_entropy_loss) else entropy_loss | |
| l2norm_code = torch.tensor(0., device=z.device, dtype=z.dtype) | |
| l2norm_z = torch.tensor(0., device=z.device, dtype=z.dtype) | |
| if self.monitor_metrics: | |
| l2norm_code = torch.norm(codebook_normed, p=2, dim=-1).mean().detach() | |
| l2norm_z = torch.norm(z_normed, p=2, dim=-1).mean().detach() | |
| loss_tuple = (quant_loss, entropy_loss, sample_entropy, batch_entropy, l2norm_z, l2norm_code) | |
| if return_continuous: | |
| return quant_ste, indices, one_hot, loss_tuple, z_normed | |
| return quant_ste, indices, one_hot, loss_tuple | |
| def encode(self, x: torch.Tensor, labels=None): | |
| B, L, D = x.shape | |
| x = x.float() | |
| labels = self.cond_proj(labels.float()) if self.cond_proj is not None and labels is not None else labels | |
| z = self.proj_in(x) | |
| z_normed = self.z_norm(z, labels) if 'Ada' in self.z_norm_type else self.z_norm(z) | |
| codebook = self.codebook | |
| codebook_normed = self.codebook_proj(codebook) | |
| logits = torch.einsum('bld,nd->bln', z_normed, codebook_normed) | |
| indices = torch.argmax(logits + self._get_pos_mask(L), dim=-1) | |
| return indices | |
| def get_codes_w_indices(self, indices, labels=None, **params): | |
| if self.codebook_init == 'LFQ': | |
| codes = self.indices_to_emb(indices) | |
| else: | |
| codes = self.embedding(indices) | |
| codes = self.codebook_proj(codes) | |
| return codes | |
| def compute_entropy_loss( | |
| logits, | |
| temperature=0.01, | |
| sample_minimization_weight=1.0, | |
| batch_maximization_weight=1.0, | |
| eps=1e-5, | |
| ): | |
| """Entropy loss on logits (affinities over the last dim); from MAGVIT (Yu et al., 2024).""" | |
| with torch.amp.autocast("cuda", enabled=False): | |
| probs = F.softmax(logits.float() / temperature, -1) | |
| log_probs = F.log_softmax(logits.float() / temperature + eps, -1) | |
| probs = probs.to(logits.dtype) | |
| log_probs = log_probs.to(logits.dtype) | |
| avg_probs = reduce(probs, "... D -> D", "mean") | |
| avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps)) | |
| sample_entropy = -torch.sum(torch.nan_to_num(probs * log_probs, nan=0.0), -1) | |
| sample_entropy = torch.mean(sample_entropy) | |
| loss = (sample_minimization_weight * sample_entropy) - ( | |
| batch_maximization_weight * avg_entropy | |
| ) | |
| return sample_entropy, avg_entropy, loss | |