prologue-demo / models.py
Bowen Zheng
init
500ee30
import math
from math import pi
from dataclasses import dataclass
from functools import partial
import torch
import torch.nn as nn
from torch import einsum, broadcast_tensors, Tensor
import torch.nn.functional as F
from torch.nn import Module
from torch.amp import autocast
from torch import nn
from einops import rearrange, repeat, reduce
import einops
import numpy as np
import os
import copy
import warnings
from utils import print0
from itertools import chain
# Well Inited Linear
def weight_init(shape, mode, fan_in, fan_out):
if mode == 'xavier_uniform': return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)
if mode == 'xavier_normal': return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape)
if mode == 'kaiming_uniform': return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1)
if mode == 'kaiming_normal': return np.sqrt(1 / fan_in) * torch.randn(*shape)
if mode == 'default': return np.sqrt(1 / fan_in) * (torch.rand(*shape) * 2 - 1) # nn.Linear default
if mode == 'trunc_normal': return torch.nn.init.trunc_normal_(torch.empty(*shape), std=0.02)
if mode == 'uniform': return torch.rand() * 2 - 1
raise ValueError(f'Invalid init mode "{mode}"')
class Linear(torch.nn.Module):
def __init__(self, in_features, out_features, bias=True, init_mode='trunc_normal', init_weight=1, init_bias=0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight)
self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None
def forward(self, x):
x = x @ self.weight.to(x.dtype).t()
if self.bias is not None:
x = x.add_(self.bias.to(x.dtype))
return x
class AdaRMSNorm(nn.Module):
def __init__(self,dim, cond_dim=None, elementwise_affine=True, cond_based_affine=True, centering=False, eps=1e-6):
super(AdaRMSNorm, self).__init__()
self.dim = dim
self.eps = eps
self.cond_based_affine = cond_based_affine
self.elementwise_affine = elementwise_affine
self.scale = dim ** 0.5
assert not(cond_dim is None and cond_based_affine), 'cond_dim must be provided if cond_based_affine is True'
if elementwise_affine:
if cond_based_affine and cond_dim is not None:
self.affine = Linear(cond_dim, dim, init_weight=1e-5)
else:
self.weight = nn.Parameter(torch.zeros(self.dim))
else:
self.register_parameter("weight", None)
def forward(self, x, cond_emb=None):
with torch.amp.autocast('cuda',enabled=False):
output = F.normalize(x.float(), dim=(-1)) * self.scale
if self.elementwise_affine:
weight = self.affine(cond_emb).unsqueeze(1) if self.cond_based_affine and cond_emb is not None else self.weight
output = output.mul(1. + weight.float())
return output.type_as(x)
class AdaLN(nn.Module):
def __init__(self,dim, cond_dim=None, elementwise_affine=True, cond_based_affine=True, bias=True, eps=1e-6):
super(AdaLN, self).__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine and not cond_based_affine, eps=eps)
self.cond_based_affine = cond_based_affine
self.bias = bias
assert not(cond_dim is None and cond_based_affine), 'cond_dim must be provided if cond_based_affine is True'
self.affine = Linear(cond_dim, 2 * dim if bias else dim, init_weight=1e-5) if cond_based_affine and cond_dim is not None else None
def forward(self, x, cond_emb=None):
x = self.norm(x)
if self.cond_based_affine:
if self.bias:
shift, scale = self.affine(cond_emb).unsqueeze(1).chunk(2, dim=-1)
x = x.mul(1. + scale).add_(shift)
else:
scale = self.affine(cond_emb).unsqueeze(1)
x = x.mul(1. + scale)
return x
class FluxRopeEMB(nn.Module):
def __init__(self, theta, axes_dim: list[int]):
super().__init__()
# theta can be a scalar (applied to all axes) or a list (per-axis theta)
if isinstance(theta, (int, float)):
self.theta = [theta] * len(axes_dim)
else:
assert len(theta) == len(axes_dim), \
f"rope theta list length {len(theta)} must match axes_dim length {len(axes_dim)}"
self.theta = list(theta)
self.axes_dim = axes_dim
@autocast('cuda',enabled=False)
def forward(self, ids: Tensor) -> Tensor:
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta[i]) for i in range(len(self.axes_dim))],
dim=-3,
)
return emb.unsqueeze(1)
@autocast('cuda',enabled=False)
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
# Accept both (n,) and (b, n) positions.
pos=pos.float()
if pos.ndim == 1:
pos = pos.unsqueeze(0)
scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
def apply_rope(x, freqs_cis):
"""Apply RoPE to ``x`` ``[B, H, L, D]`` using ``freqs_cis`` ``[*, 1, S, D/2, 2, 2]``."""
if freqs_cis is None:
raise ValueError("freqs_cis is None but RoPE is enabled. Did you forget to pass q_pe/k_pe?")
b, h, l, d = x.shape
if d % 2 != 0:
raise ValueError(f"RoPE requires last dim even, got D={d}")
x_dtype = x.dtype
x_float = x.float().reshape(b, h, l, d // 2, 2)
freqs = freqs_cis.to(device=x.device, dtype=x_float.dtype)
# Matrix-vector multiply using columns: y = M[...,0]*x0 + M[...,1]*x1
x0 = x_float[..., 0]
x1 = x_float[..., 1]
x_out = freqs[..., 0] * x0.unsqueeze(-1) + freqs[..., 1] * x1.unsqueeze(-1) # (..., 2)
return x_out.reshape(b, h, l, d).to(dtype=x_dtype)
# DropPath & FFN
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): # taken from timm
if drop_prob == 0. or not training: return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
class DropPath(nn.Module): # taken from timm
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f'(drop_prob=...)'
class FFN(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., ffn_type='geglu', out_value=1.0, ffn_bias=True):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.ffn_type = ffn_type
if self.ffn_type == 'geglu':
self.fc1 = Linear(in_features, 2*hidden_features, bias=ffn_bias)
self.act = nn.GELU(approximate='tanh')
self.fc2 = Linear(hidden_features, out_features, bias=ffn_bias, init_weight=out_value)
elif self.ffn_type == 'ffn':
self.fc1 = Linear(in_features, hidden_features, bias=ffn_bias)
self.act = nn.GELU(approximate='tanh')
self.fc2 = Linear(hidden_features, out_features, bias=ffn_bias, init_weight=out_value)
self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()
def forward(self, x):
if self.ffn_type == 'geglu':
gate, value = self.fc1(x).chunk(2, dim=-1)
gated = self.act(value) * gate
return self.drop(self.fc2(gated))
elif self.ffn_type == 'ffn':
return self.drop(self.fc2(self.act(self.fc1(x))))
class Attention(nn.Module):
def __init__(
self, block_idx, embed_dim=768, num_heads=12,
attn_drop=0., proj_drop=0., attn_norm=False,
rope=False, out_value=1.0, attn_out_bias=True
):
super().__init__()
assert embed_dim % num_heads == 0
self.block_idx = block_idx
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.attn_norm = attn_norm
self.rope = rope
self.q_norm = AdaRMSNorm(self.head_dim,cond_based_affine=False) if self.attn_norm else None
self.k_norm = AdaRMSNorm(self.head_dim,cond_based_affine=False) if self.attn_norm else None
self.to_q = Linear(embed_dim, embed_dim, bias=False)
self.to_kv = Linear(embed_dim, embed_dim * 2, bias=False)
self.proj = Linear(embed_dim, embed_dim, bias=attn_out_bias, init_weight=out_value)
self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
self.attn_drop = attn_drop
def forward(self, x, context_emb=None, causal=False, attn_bias=None, cache_kv=False, past_kvs=None,pe=None,ctx_pe=None, return_attn=False):
B, L, C = x.shape
q = self.to_q(x)
q = einops.rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads)
kv = self.to_kv(x) if context_emb is None else self.to_kv(context_emb)
k, v = einops.rearrange(kv, 'b l (k h d) -> k b h l d', k=2, h=self.num_heads)
q = self.q_norm(q) if self.q_norm is not None else q
k = self.k_norm(k) if self.k_norm is not None else k
if self.rope:
q = apply_rope(q, pe)
if context_emb is not None and ctx_pe is not None:
k = apply_rope(k, ctx_pe)
else:
k = apply_rope(k, pe)
if attn_bias is not None:
attn_bias = attn_bias.unsqueeze(1).expand(B, self.num_heads, L, L)
if past_kvs is not None:
past_keys, past_values = past_kvs
k = torch.cat((past_keys, k), dim=2)
v = torch.cat((past_values, v), dim=2)
if return_attn:
scale = self.head_dim ** -0.5
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
if causal and past_kvs is None and attn_bias is None:
L_q, L_k = q.size(2), k.size(2)
causal_mask = torch.triu(torch.ones(L_q, L_k, device=q.device, dtype=torch.bool), diagonal=1)
attn_weights.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
if attn_bias is not None:
attn_weights = attn_weights + attn_bias
attn_weights = F.softmax(attn_weights, dim=-1)
oup = torch.matmul(attn_weights, v).transpose(1, 2).reshape(B, L, -1)
else:
attn_weights = None
dropout_p = self.attn_drop if self.training else 0.0
oup = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
is_causal=causal and past_kvs is None and attn_bias is None,
dropout_p=dropout_p,
attn_mask=attn_bias
).transpose(1, 2).reshape(B, L, -1)
out = self.proj_drop(self.proj(oup))
if return_attn:
return (out, (k, v), attn_weights) if cache_kv else (out, attn_weights)
else:
return (out, (k, v)) if cache_kv else out
class TransformerLayer(nn.Module):
def __init__(
self, block_idx, embed_dim, cond_dim,
num_heads, mlp_ratio=8/3, drop=0., attn_drop=0., drop_path=0., attn_norm=False, cross_attn=False, rope=False,
ffn_type='geglu', norm_layer=AdaRMSNorm, cond_based_affine=True, ffn_bias=True, attn_out_bias=True
):
super(TransformerLayer, self).__init__()
self.block_idx = block_idx
self.cond_based_affine = cond_based_affine
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
out_value = 1.0 if cond_based_affine else 1e-5
self.adaln_1 = norm_layer(embed_dim, cond_dim, cond_based_affine=cond_based_affine)
self.attn1 = Attention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_norm=attn_norm, rope=rope, out_value=out_value, attn_out_bias=attn_out_bias)
self.gate1 = Linear(cond_dim, embed_dim, bias=ffn_bias, init_weight=1e-5) if cond_based_affine else None
if cross_attn:
self.has_cross_attn = True
self.adaln_2 = norm_layer(embed_dim, cond_dim, cond_based_affine=cond_based_affine)
self.adaln_ctx = norm_layer(embed_dim, cond_dim, cond_based_affine=cond_based_affine)
self.attn2 = Attention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_norm=attn_norm, rope=rope, out_value=out_value, attn_out_bias=attn_out_bias)
self.gate2 = Linear(cond_dim, embed_dim, bias=ffn_bias, init_weight=1e-5) if cond_based_affine else None
else:
self.has_cross_attn = False
self.adaln_mlp = norm_layer(embed_dim, cond_dim, cond_based_affine=cond_based_affine)
mlp_ratio = eval(mlp_ratio) if isinstance(mlp_ratio, str) else mlp_ratio
self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, ffn_type=ffn_type, out_value=out_value, ffn_bias=ffn_bias)
self.gate_mlp = Linear(cond_dim, embed_dim, bias=ffn_bias, init_weight=1e-5) if cond_based_affine else None
def forward(self, x,context_emb=None, cond_emb=None, causal=False, attn_mask=None, cache_kv=False, past_kvs=None,pe=None,ctx_pe=None, return_attn=False):
self_attn_bias = attn_mask if (attn_mask is not None and not self.has_cross_attn) else None
self_causal = causal if self_attn_bias is None else False
attn_out1 = self.attn1(self.adaln_1(x, cond_emb), context_emb=None,causal=self_causal, attn_bias=self_attn_bias, cache_kv=cache_kv, past_kvs=past_kvs,pe=pe,ctx_pe=None, return_attn=return_attn)
if return_attn:
if cache_kv:
attn_out1, new_kvs, layer_attn_weights = attn_out1
else:
attn_out1, layer_attn_weights = attn_out1
elif cache_kv:
attn_out1, new_kvs = attn_out1[0], attn_out1[1]
gate1 = self.gate1(cond_emb).unsqueeze(1) if self.gate1 is not None else 1.0
x = x + gate1 * self.drop_path(attn_out1)
if context_emb is not None and self.has_cross_attn:
gate2 = self.gate2(cond_emb).unsqueeze(1) if self.gate2 is not None else 1.0
x = x + gate2 * self.drop_path(self.attn2(self.adaln_2(x, cond_emb), context_emb=self.adaln_ctx(context_emb,cond_emb), causal=False, attn_bias=attn_mask,pe=pe,ctx_pe=ctx_pe))
gate_mlp = self.gate_mlp(cond_emb).unsqueeze(1) if self.gate_mlp is not None else 1.0
x = x + gate_mlp * self.drop_path(self.ffn(self.adaln_mlp(x, cond_emb)))
if return_attn:
return (x, new_kvs, layer_attn_weights) if cache_kv else (x, layer_attn_weights)
elif cache_kv:
return x, new_kvs
else:
return x
class Transformer(nn.Module):
def __init__(
self,
layer_num,
input_dim, dim, output_dim, max_seq_len, heads,
cond_input_dim, cond_dim,
context_type='none', ctx_input_dim=None, ctx_max_seq_len=None,
emb_dropout=0., drop=0., attn_drop=0.,
rope=False, abs_pos=False,
attn_norm=False, causal=False,
zero_out=False, noise_query = False,
ffn_type='geglu', # ffn or geglu
norm_layer='RMSNorm', # LayerNorm or RMSNorm
mlp_ratio=8/3,
rope_theta=2000,
rope_ctx_theta=10000,
rope_axes_dim=[32,32,32],
rope_input_type='l',
rope_hybrid_l_len=0,
rope_ctx_type='none',
out_norm=True,
out_layer=True,
out_act=False,
input_bias=True,
input_layer=True,
use_label=True,
ffn_bias=True,
out_bias=True,
attn_out_bias=True,
bos_zero_rope=False,
**params
):
super(Transformer, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.noise_query = noise_query
self.cond_input_dim = cond_input_dim
self.dim = dim
self.causal = causal
self.bos_zero_rope = bos_zero_rope
if bos_zero_rope:
print0("[Transformer] bos_zero_rope=True: BOS token will use zero RoPE position (pos=0)")
self.concat_causal = (context_type == 'concat' and causal)
if self.concat_causal:
self.causal = False
_ctx_len = ctx_max_seq_len
_total = max_seq_len + _ctx_len
_mask = torch.zeros(1, _total, _total)
_mask[:, :_ctx_len, _ctx_len:] = float('-inf')
_mask[:, _ctx_len:, _ctx_len:].masked_fill_(
torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1),
float('-inf'))
self.register_buffer('concat_causal_mask', _mask)
self.cross_attn = context_type=='cross_attn'
print0(norm_layer)
norm_layer = AdaRMSNorm if norm_layer=="RMSNorm" else AdaLN
self.use_label = use_label
self.layers = nn.ModuleList([
TransformerLayer(block_idx=i, embed_dim=dim,
cond_dim=cond_dim, num_heads=heads, mlp_ratio=mlp_ratio,
drop=drop, attn_drop=attn_drop, attn_norm=attn_norm, rope=rope, cross_attn=self.cross_attn,
ffn_type=ffn_type, norm_layer=norm_layer, cond_based_affine=self.use_label, ffn_bias=ffn_bias, attn_out_bias=attn_out_bias)
for i in range(layer_num)
])
out_weight = 1e-5 if zero_out else 1
scale = np.sqrt(dim)
self.rope_input_type = rope_input_type
self.rope_hybrid_l_len = rope_hybrid_l_len
self.rope_ctx_type = rope_ctx_type
self.out_norm = norm_layer(dim, cond_dim, cond_based_affine=self.use_label) if out_norm else None
self.out_activate = nn.GELU(approximate='tanh') if out_act else None
self.out = Linear(dim, output_dim, bias=out_bias, init_weight=out_weight) if out_layer else None
if self.use_label:
self.cond_proj = nn.Sequential(Linear(cond_input_dim, cond_dim), norm_layer(cond_dim, cond_based_affine=False), nn.GELU(approximate='tanh'))
else:
self.cond_proj = None
self.max_seq_len = max_seq_len
self.context_type = context_type # "concat, cross_attn, none"
if abs_pos:
if context_type == 'none':
self.abs_pos = nn.Embedding(max_seq_len, dim)
nn.init.trunc_normal_(self.abs_pos.weight, std=0.02)
elif context_type == 'concat':
self.abs_pos = nn.Embedding(max_seq_len + ctx_max_seq_len, dim)
nn.init.trunc_normal_(self.abs_pos.weight, std=0.02)
elif context_type == 'cross_attn':
self.abs_pos = nn.Embedding(max_seq_len, dim)
self.ctx_pos = nn.Embedding(ctx_max_seq_len, dim)
nn.init.trunc_normal_(self.abs_pos.weight, std=0.02)
nn.init.trunc_normal_(self.ctx_pos.weight, std=0.02)
else:
raise ValueError(f"Invalid context_type: {context_type}")
else:
self.abs_pos = None
if rope:
self.rope_emb = FluxRopeEMB(theta=rope_theta, axes_dim=rope_axes_dim)
if context_type != 'none':
self.ctx_rope_emb = FluxRopeEMB(theta=rope_ctx_theta, axes_dim=rope_axes_dim)
else:
self.rope_emb = None
self.ctx_rope_emb = None
self.input_layer = Linear(input_dim, dim, bias=input_bias) if input_layer else nn.Identity()
if context_type != 'none':
self.learnable_query = nn.Parameter(torch.empty(1, max_seq_len, dim))
nn.init.trunc_normal_(self.learnable_query, std=0.02)
self.emb_dropout = nn.Dropout(p=emb_dropout) if emb_dropout > 0 else nn.Identity()
def prepare_pos_ids(self, x_len, z_len, past_len,device):
if self.context_type == 'none':
if self.rope_input_type == 'hw':
w = int(np.sqrt(self.max_seq_len))
idx = torch.arange(past_len, past_len + x_len, device=device)
x_ids = ctx_ids = torch.stack([
(idx // w).float(),
(idx % w).float(),
torch.zeros(x_len, device=device, dtype=torch.float32),
], dim=1)
elif self.rope_input_type == 'l':
l = x_len
x_coords = {
"l": torch.arange(past_len, past_len + l, device=device, dtype=torch.float32),
"h": torch.arange(1, device=device, dtype=torch.float32),
"w": torch.arange(1, device=device, dtype=torch.float32),
}
x_ids = ctx_ids = torch.cartesian_prod( x_coords["l"], x_coords["h"], x_coords["w"])
elif self.rope_input_type == 'hybrid':
l_len = self.rope_hybrid_l_len
hw_total = self.max_seq_len - l_len
h = w = int(np.sqrt(hw_total))
end_pos = past_len + x_len
all_ids = []
sem_start = past_len
sem_end = min(end_pos, l_len)
if sem_start < sem_end:
n = sem_end - sem_start
all_ids.append(torch.stack([
torch.arange(sem_start, sem_end, device=device, dtype=torch.float32),
torch.zeros(n, device=device, dtype=torch.float32),
torch.zeros(n, device=device, dtype=torch.float32),
], dim=1))
vis_start = max(past_len, l_len)
if vis_start < end_pos:
n = end_pos - vis_start
offset = vis_start - l_len
idx = torch.arange(offset, offset + n, device=device)
all_ids.append(torch.stack([
torch.zeros(n, device=device, dtype=torch.float32),
(idx // w).float(),
(idx % w).float(),
], dim=1))
x_ids = ctx_ids = torch.cat(all_ids, dim=0)
elif self.context_type == 'concat':
if self.rope_input_type == 'hw':
h = w = int(np.sqrt(z_len))
x_coords = {
"h": torch.arange(h, device=device, dtype=torch.float32),
"w": torch.arange(w, device=device, dtype=torch.float32),
"l": torch.arange(1, device=device, dtype=torch.float32),
}
elif self.rope_input_type == 'l':
l = z_len
x_coords = {
"h": torch.arange(1, device=device, dtype=torch.float32),
"w": torch.arange(1, device=device, dtype=torch.float32),
"l": torch.arange(l, device=device, dtype=torch.float32),
}
if self.rope_ctx_type == 'hw':
h = w = int(np.sqrt(x_len))
ctx_coords = {
"h": torch.arange(h, device=device, dtype=torch.float32),
"w": torch.arange(w, device=device, dtype=torch.float32),
"l": torch.arange(1, device=device, dtype=torch.float32),
}
elif self.rope_ctx_type == 'l':
l = x_len
ctx_coords = {
"h": torch.arange(1, device=device, dtype=torch.float32),
"w": torch.arange(1, device=device, dtype=torch.float32),
"l": torch.arange(l, device=device, dtype=torch.float32),
}
x_ids = torch.cartesian_prod(x_coords["h"], x_coords["w"], x_coords["l"])
ctx_ids = torch.cartesian_prod(ctx_coords["h"], ctx_coords["w"], ctx_coords["l"])
elif self.context_type == 'cross_attn':
if self.rope_input_type == 'hw':
h = w = int(np.sqrt(z_len))
x_coords = {
"h": torch.arange(h, device=device, dtype=torch.float32),
"w": torch.arange(w, device=device, dtype=torch.float32),
"l": torch.arange(1, device=device, dtype=torch.float32),
}
elif self.rope_input_type == 'l':
l = z_len
x_coords = {
"h": torch.arange(1, device=device, dtype=torch.float32),
"w": torch.arange(1, device=device, dtype=torch.float32),
"l": torch.arange(l, device=device, dtype=torch.float32),
}
if self.rope_ctx_type == 'hw':
h = w = int(np.sqrt(x_len))
ctx_coords = {
"h": torch.arange(h, device=device, dtype=torch.float32),
"w": torch.arange(w, device=device, dtype=torch.float32),
"l": torch.arange(1, device=device, dtype=torch.float32),
}
elif self.rope_ctx_type == 'l':
l = x_len
ctx_coords = {
"h": torch.arange(1, device=device, dtype=torch.float32),
"w": torch.arange(1, device=device, dtype=torch.float32),
"l": torch.arange(l, device=device, dtype=torch.float32),
}
x_ids = torch.cartesian_prod(x_coords["h"], x_coords["w"], x_coords["l"])
ctx_ids = torch.cartesian_prod(ctx_coords["h"], ctx_coords["w"], ctx_coords["l"])
return x_ids, ctx_ids
def forward(self, x, condition, attn_mask=None, cache_kv=False, past_kvs=None, return_attn=False):
x_len = x.size(1)
z_len = self.max_seq_len if self.context_type != 'none' else 0
if self.context_type == 'none':
x, context_emb = self.input_layer(x), None
elif self.context_type == 'concat':
learnable_query = self.learnable_query.expand(x.size(0), -1, -1).to(x.device)
x = self.input_layer(x)
x = torch.cat((x, learnable_query), dim=1)
x, context_emb = x, None
else:
learnable_query = self.learnable_query.expand(x.size(0), -1, -1).to(x.device)
x, context_emb = learnable_query, self.input_layer(x)
x = self.emb_dropout(x)
cond_emb = self.cond_proj[0](condition) if (self.cond_proj is not None and condition is not None) else None
cond_emb = self.emb_dropout(cond_emb) if cond_emb is not None else None
cond_emb = self.cond_proj[1](cond_emb) if cond_emb is not None else None
cond_emb = self.cond_proj[2](cond_emb) if cond_emb is not None else None
if self.abs_pos is not None:
if self.context_type == 'none':
pos_ids = torch.arange(x_len, device=x.device)
if past_kvs is not None:
pos_ids = pos_ids + past_kvs[0][0].shape[2]
x = x + self.abs_pos(pos_ids)
elif self.context_type == 'concat':
pos_ids = torch.arange(x_len+z_len, device=x.device)
x = x + self.abs_pos(pos_ids)
elif self.context_type == 'cross_attn':
pos_ids = torch.arange(x_len, device=x.device)
x = x + self.abs_pos(pos_ids)
if context_emb is not None and self.ctx_pos is not None:
context_emb = context_emb + self.ctx_pos(torch.arange(z_len, device=context_emb.device))
pe = ctx_pe = None
if self.rope_emb is not None:
past_len = 0
if past_kvs is not None:
past_len = past_kvs[0][0].shape[2]
x_ids, ctx_ids = self.prepare_pos_ids(x_len, z_len, past_len, x.device)
pe = self.rope_emb(x_ids)
if self.context_type != 'none':
ctx_pe = self.ctx_rope_emb(ctx_ids)
if self.context_type == 'concat':
pe = torch.cat([ctx_pe, pe], dim=2)
ctx_pe = None
if self.bos_zero_rope and past_len == 0:
pe = pe.clone()
pe[:, :, 0, :, 0, 0] = 0 # cos = 0
pe[:, :, 0, :, 0, 1] = 0 # -sin = 0
pe[:, :, 0, :, 1, 0] = 0 # sin = 0
pe[:, :, 0, :, 1, 1] = 0 # cos = 0
attn_mask = self.concat_causal_mask if self.concat_causal else attn_mask
kv_caches = []
all_attn_maps = [] if return_attn else None
for i,layer in enumerate(self.layers):
layer_output = layer(x, context_emb, cond_emb, self.causal, attn_mask, cache_kv, past_kvs[i] if past_kvs is not None else None, pe=pe, ctx_pe=ctx_pe, return_attn=return_attn)
if return_attn:
if cache_kv:
x, new_kvs, layer_attn = layer_output
kv_caches.append(new_kvs)
else:
x, layer_attn = layer_output
all_attn_maps.append(layer_attn)
elif cache_kv:
x, new_kvs = layer_output[0], layer_output[1]
kv_caches.append(new_kvs)
else:
x = layer_output
with torch.amp.autocast('cuda', enabled=False):
x = x.float()
if cond_emb is not None:
cond_emb = cond_emb.float()
if self.out_norm is not None:
x = self.out_norm(x, cond_emb)
if self.out_activate is not None:
x = self.out_activate(x)
if self.out is not None:
x = self.out(x)
if return_attn:
if cache_kv:
return x, kv_caches, all_attn_maps
else:
return x, all_attn_maps
elif cache_kv:
return x, kv_caches
else:
return x
def insert_eos_token(idx, z_len, eos_id):
"""Insert eos_id at position z_len in idx, returning a tensor with length +1."""
return torch.cat([idx[:, :z_len],
torch.full((idx.shape[0], 1), eos_id, device=idx.device, dtype=idx.dtype),
idx[:, z_len:]], dim=1)
@dataclass
class VQLossDetail:
quant_loss: Tensor
entropy_loss: Tensor
sample_entropy: Tensor
batch_entropy: Tensor
l2norm_z: Tensor
l2norm_code: Tensor
@staticmethod
def zero(device):
z = torch.zeros((), device=device)
return VQLossDetail(z, z, z, z, z, z)
@staticmethod
def from_tuple(t):
return VQLossDetail(*t)
@dataclass
class EncoderOutput:
quant: Tensor
indices: Tensor
one_hot: Tensor
semantic_vq_loss: VQLossDetail
visual_vq_loss: VQLossDetail
semantic_quant: Tensor
visual_quant: Tensor
semantic_indices: Tensor
visual_indices: Tensor
semantic_one_hot: Tensor = None
@dataclass
class AROutput:
logits: Tensor
semantic_logits: Tensor
visual_logits: Tensor
class Encoder(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.prologue = config.get("Prologue", False)
self.share_semantic_encoder = config.get("share_semantic_encoder", True)
self.share_semantic_codebook = config.get("share_semantic_codebook", False)
self.z_len = config["z_len"]
self.x_len = config["x_len"]
if self.prologue:
if self.share_semantic_encoder and config["Encoder"].get("context_type") != "concat":
warnings.warn(f"Prologue with shared backbone requires Encoder context_type='concat', got '{config['Encoder'].get('context_type')}'. Forcing to 'concat'.")
config["Encoder"]["context_type"] = "concat"
self.enc = Transformer(**config["Encoder"])
self.quantizer = PrologueQuantizer(**config["Quantizer"])
if self.prologue and not self.share_semantic_codebook:
self.semantic_quantizer = PrologueQuantizer(**config["SemanticQuantizer"])
if self.prologue and not self.share_semantic_encoder:
self.semantic_input_type = config.get("semantic_input_type", "encoder_output")
self.semantic_enc = Transformer(**config["SemanticEncoder"])
if not self.prologue or self.share_semantic_codebook:
self._forward_impl = self._forward_simple
self._encode_idx_impl = self._encode_idx_simple
elif self.share_semantic_encoder:
self._forward_impl = self._forward_split_codebook
self._encode_idx_impl = self._encode_idx_split_codebook
else:
self._forward_impl = self._forward_separate_enc
self._encode_idx_impl = self._encode_idx_separate_enc
def forward(self, x, labels, training=False) -> EncoderOutput:
return self._forward_impl(x, labels, training)
def encode_idx(self, x: torch.Tensor, labels=None) -> torch.Tensor:
return self._encode_idx_impl(x, labels)
# ---- strategy: simple (no Prologue, or Prologue with shared codebook) ----
def _forward_simple(self, x, labels, training):
h = self.enc(x, labels)
if self.enc.context_type == 'concat':
h = h[:, -self.enc.max_seq_len:]
quant, idx, one_hot, vqloss = self.quantizer(h, labels, training=training)
vl = VQLossDetail.from_tuple(vqloss) if isinstance(vqloss, tuple) else VQLossDetail(vqloss, *([torch.zeros((), device=quant.device)] * 5))
return EncoderOutput(
quant=quant, indices=idx, one_hot=one_hot,
semantic_vq_loss=None, visual_vq_loss=vl,
semantic_quant=None, visual_quant=quant,
semantic_indices=None, visual_indices=idx,
semantic_one_hot=None,
)
def _encode_idx_simple(self, x, labels):
z = self.enc(x, labels)
if self.enc.context_type == 'concat':
z = z[:, -self.enc.max_seq_len:]
return self.quantizer.encode(z, labels)
# ---- strategy: split codebook (Prologue, shared encoder, separate codebooks) ----
def _forward_split_codebook(self, x, labels, training):
h = self.enc(x, labels)
h_v, h_s = h[:, :self.x_len, :], h[:, self.x_len:, :]
quant_v, idx_v, oh_v, loss_v = self.quantizer(h_v, labels, training=training)
quant_s, idx_s, oh_s, loss_s = self.semantic_quantizer(h_s, labels, training=training)
raw_oh_s = oh_s
idx = torch.cat([idx_s, idx_v], dim=1)
if oh_s is not None and oh_v is not None:
max_cb = max(oh_s.shape[-1], oh_v.shape[-1])
if oh_s.shape[-1] < max_cb:
oh_s = F.pad(oh_s, (0, max_cb - oh_s.shape[-1]))
if oh_v.shape[-1] < max_cb:
oh_v = F.pad(oh_v, (0, max_cb - oh_v.shape[-1]))
one_hot = torch.cat([oh_s, oh_v], dim=1)
else:
one_hot = None
vl_s = VQLossDetail.from_tuple(loss_s) if isinstance(loss_s, tuple) else VQLossDetail(loss_s, *([torch.zeros((), device=quant_v.device)] * 5))
vl_v = VQLossDetail.from_tuple(loss_v) if isinstance(loss_v, tuple) else VQLossDetail(loss_v, *([torch.zeros((), device=quant_v.device)] * 5))
return EncoderOutput(
quant=quant_v, indices=idx, one_hot=one_hot,
semantic_vq_loss=vl_s, visual_vq_loss=vl_v,
semantic_quant=quant_s, visual_quant=quant_v,
semantic_indices=idx_s, visual_indices=idx_v,
semantic_one_hot=raw_oh_s,
)
def _encode_idx_split_codebook(self, x, labels):
z = self.enc(x, labels)
z_v, z_s = z[:, :self.x_len, :], z[:, self.x_len:, :]
idx_v = self.quantizer.encode(z_v, labels)
idx_s = self.semantic_quantizer.encode(z_s, labels)
return torch.cat([idx_s, idx_v], dim=1)
# ---- strategy: separate encoder (Prologue-Post, independent backbone) ----
def _forward_separate_enc(self, x, labels, training):
h = self.enc(x, labels)
vq_out = self.quantizer(h, labels, training=training,
return_continuous=(self.semantic_input_type == "pre_quant"))
if self.semantic_input_type == "pre_quant":
quant_v, idx_v, oh_v, loss_v, z_continuous = vq_out
sem_input = z_continuous
else:
quant_v, idx_v, oh_v, loss_v = vq_out
sem_input = h if self.semantic_input_type == "encoder_output" else x
h_s = self.semantic_enc(sem_input, labels)
if self.semantic_enc.context_type == 'concat':
h_s = h_s[:, -self.semantic_enc.max_seq_len:]
quant_s, idx_s, oh_s, loss_s = self.semantic_quantizer(h_s, labels, training=training)
raw_oh_s = oh_s
idx = torch.cat([idx_s, idx_v], dim=1)
if oh_s is not None and oh_v is not None:
max_cb = max(oh_s.shape[-1], oh_v.shape[-1])
if oh_s.shape[-1] < max_cb:
oh_s = F.pad(oh_s, (0, max_cb - oh_s.shape[-1]))
if oh_v.shape[-1] < max_cb:
oh_v = F.pad(oh_v, (0, max_cb - oh_v.shape[-1]))
one_hot = torch.cat([oh_s, oh_v], dim=1)
else:
one_hot = None
vl_s = VQLossDetail.from_tuple(loss_s) if isinstance(loss_s, tuple) else VQLossDetail(loss_s, *([torch.zeros((), device=quant_v.device)] * 5))
vl_v = VQLossDetail.from_tuple(loss_v) if isinstance(loss_v, tuple) else VQLossDetail(loss_v, *([torch.zeros((), device=quant_v.device)] * 5))
return EncoderOutput(
quant=quant_v, indices=idx, one_hot=one_hot,
semantic_vq_loss=vl_s, visual_vq_loss=vl_v,
semantic_quant=quant_s, visual_quant=quant_v,
semantic_indices=idx_s, visual_indices=idx_v,
semantic_one_hot=raw_oh_s,
)
def _encode_idx_separate_enc(self, x, labels):
h = self.enc(x, labels)
idx_v = self.quantizer.encode(h, labels)
if self.semantic_input_type == "encoder_output":
sem_input = h
elif self.semantic_input_type == "image_patch":
sem_input = x
else:
vq_out = self.quantizer(h, labels, training=False, return_continuous=True)
sem_input = vq_out[4]
h_s = self.semantic_enc(sem_input, labels)
if self.semantic_enc.context_type == 'concat':
h_s = h_s[:, -self.semantic_enc.max_seq_len:]
idx_s = self.semantic_quantizer.encode(h_s, labels)
return torch.cat([idx_s, idx_v], dim=1)
# ---- helper properties & methods ----
@property
def has_separate_semantic(self):
return self.prologue and not self.share_semantic_encoder
@property
def visual_modules(self):
return ["enc", "quantizer"]
@property
def semantic_modules(self):
return ["semantic_enc", "semantic_quantizer"] if self.has_separate_semantic else []
@property
def total_token_len(self):
if self.prologue and not self.share_semantic_codebook:
return self.z_len + self.x_len
return self.z_len
def get_visual_codes(self, indices, labels):
if self.prologue and not self.share_semantic_codebook:
visual_ids = indices[:, -self.x_len:]
return self.quantizer.get_codes_w_indices(visual_ids, labels)
return self.quantizer.get_codes_w_indices(indices, labels)
def visual_parameters(self):
return chain(self.enc.parameters(), self.quantizer.parameters())
def semantic_parameters(self):
return chain(self.semantic_enc.parameters(), self.semantic_quantizer.parameters())
class Decoder(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.dec = Transformer(**config["Decoder"])
self.is_concat = config["Decoder"].get("context_type", "none") == "concat"
self.output_len = config["Decoder"]["max_seq_len"]
def forward(self, quant, labels):
x_hat = self.dec(quant, labels)
if self.is_concat:
x_hat = x_hat[:, -self.output_len:]
return x_hat
class ARModel(nn.Module):
def __init__(self, config) -> None:
super().__init__()
ar_config = config["ARModel"]
self.conditional_injection = ar_config.get("conditional_injection", "llamagen") # dit | llamagen
# ste_ar_embedding: use one_hot @ W_emb for the prologue prefix so grads flow into the encoder.
self.ste_ar_embedding = ar_config.get("ste_ar_embedding", False)
self.num_classes = ar_config.cond_input_dim
print0("conditional_injection: ", self.conditional_injection)
print0("ste_ar_embedding: ", self.ste_ar_embedding)
print0("num_classes: ", self.num_classes)
prologue = config.get("Prologue", False) and not config.get("share_semantic_codebook", False)
self.z_len = int(config.get("z_len", 0)) if prologue else 0
vis_cb_size = int(config["Quantizer"]["codebook_size"])
sem_cb_size = int(config["SemanticQuantizer"]["codebook_size"]) if prologue else 0
self.ar_vocab_size = vis_cb_size + sem_cb_size
self.semantic_offset = vis_cb_size
self.semantic_codebook_size = sem_cb_size
self.visual_codebook_size = vis_cb_size
use_eos = bool(config.get("use_eos", False))
self.eos_len = 1 if (use_eos and self.z_len > 0) else 0
if self.eos_len > 0:
self.eos_token_id = self.ar_vocab_size
self.ar_vocab_size += 1
ar_config['max_seq_len'] = int(ar_config['max_seq_len']) + 1
if 'rope_hybrid_l_len' in ar_config and int(ar_config['rope_hybrid_l_len']) > 0:
ar_config['rope_hybrid_l_len'] = int(ar_config['rope_hybrid_l_len']) + 1
print0(f"use_eos: eos_token_id={self.eos_token_id}, max_seq_len→{ar_config['max_seq_len']}")
else:
self.eos_token_id = -1
print0(f"AR vocab: {self.ar_vocab_size}, semantic_offset: {self.semantic_offset}")
bos_num = self.num_classes if self.conditional_injection == "llamagen" else 1
self.bos_emb = nn.Embedding(bos_num, ar_config["dim"])
nn.init.trunc_normal_(self.bos_emb.weight, std=0.02)
self.semantic_emb = nn.Embedding(self.ar_vocab_size, ar_config["dim"])
nn.init.trunc_normal_(self.semantic_emb.weight, std=0.02)
self.tied_embedding = ar_config.get("tied_embedding", False)
if self.tied_embedding:
ar_config['out_layer'] = False
print0("tied_embedding: True, output layer disabled in Transformer")
ar_config['output_dim'] = self.ar_vocab_size
self.ar_model = Transformer(**ar_config)
self.temperature = ar_config["temperature"]
self.max_length = ar_config["max_seq_len"]
uncond_idx = self.num_classes - 1
self.register_buffer(
'uncond_ar_labels',
F.one_hot(torch.tensor([uncond_idx], dtype=torch.long), num_classes=self.num_classes).float()
)
self.register_buffer('logit_mask', None, persistent=False)
def forward(self, idx, labels, temperature=None, semantic_one_hot=None, return_attn=False) -> AROutput:
bz = idx.shape[0]
if self.conditional_injection == "llamagen":
bos = self.bos_emb(torch.argmax(labels, dim=1)).unsqueeze(1)
labels = self.uncond_ar_labels.expand(bz, -1).to(device=labels.device, dtype=labels.dtype)
else:
bos = self.bos_emb(torch.zeros(bz, device=idx.device, dtype=torch.long)).unsqueeze(1)
# STE: prefix embedding via one_hot @ W_emb so grads reach the encoder.
if self.ste_ar_embedding and semantic_one_hot is not None and self.z_len > 0:
sem_weight = self.semantic_emb.weight[self.semantic_offset:self.semantic_offset + self.semantic_codebook_size]
sem_token_emb = semantic_one_hot @ sem_weight
rest_token_emb = self.semantic_emb(idx[:, self.z_len:-1])
token_emb = torch.cat([sem_token_emb, rest_token_emb], dim=1)
elif self.ste_ar_embedding and semantic_one_hot is not None:
weight = self.semantic_emb.weight[:self.visual_codebook_size]
token_emb = semantic_one_hot[:, :-1] @ weight
else:
token_emb = self.semantic_emb(idx[:, :-1])
shift_input = torch.cat([bos, token_emb], dim=1)
result = self.ar_model(shift_input, labels, return_attn=return_attn)
if return_attn:
out, all_attn_maps = result
else:
out = result
logits = F.linear(out, self.semantic_emb.weight) if self.tied_embedding else out
if temperature is not None:
logits = logits / temperature
elif self.temperature is not None:
logits = logits / self.temperature
if self.z_len > 0:
semantic_logits = logits[:, :self.z_len]
visual_logits = logits[:, self.z_len + self.eos_len:]
else:
semantic_logits = logits
visual_logits = logits
ar_output = AROutput(logits=logits, semantic_logits=semantic_logits, visual_logits=visual_logits)
if return_attn:
return ar_output, all_attn_maps
return ar_output
def set_logit_mask(self, mask):
"""Install per-step ``logit_mask`` buffer; pads EOS row/slot at ``z_len`` when ``use_eos``."""
if self.eos_len > 0:
if mask is not None:
mask = F.pad(mask, (0, 1), value=float('-inf'))
eos_row = torch.full((1, self.ar_vocab_size), float('-inf'))
eos_row[0, self.eos_token_id] = 0.
mask = torch.cat([mask[:self.z_len], eos_row, mask[self.z_len:]], dim=0)
else:
mask = torch.zeros(self.max_length, self.ar_vocab_size)
mask[:, self.eos_token_id] = float('-inf')
mask[self.z_len] = float('-inf')
mask[self.z_len, self.eos_token_id] = 0.
self.register_buffer('logit_mask', mask, persistent=False)
@torch.no_grad()
@torch._dynamo.disable
def sampling(self, bz, class_label=None, temperature=1.0, topK=None, topP=None, cfg=16.0, cfg_schedule='cosine', cfg_power=2.75, cache_kv=False,
semantic_cfg_schedule=None, semantic_cfg_scale=None, semantic_cfg_power=None, semantic_cfg_start=0.0,
visual_cfg_schedule=None, visual_cfg_scale=None, visual_cfg_power=None, visual_cfg_start=1.0,
semantic_temperature=None):
cfg = 0. if class_label is None else cfg
use_segmented = (semantic_cfg_schedule is not None or visual_cfg_schedule is not None)
use_cfg = cfg > 0.
if use_segmented:
_ss = semantic_cfg_scale if semantic_cfg_scale is not None else cfg
_vs = visual_cfg_scale if visual_cfg_scale is not None else cfg
use_cfg = _ss > 0. or _vs > 0.
uncond_idx = int(self.ar_model.cond_input_dim) - 1
device = self.bos_emb.weight.device
if self.conditional_injection == "llamagen":
cond_bos = self.bos_emb(torch.argmax(class_label, dim=1)).unsqueeze(1) # [B, 1, D]
uncond_bos = self.bos_emb(torch.full((bz,), uncond_idx, device=device, dtype=torch.long)).unsqueeze(1)
ar_labels = self.uncond_ar_labels.expand(bz, -1).to(device=device) # [B, num_classes]
uncond_labels = self.uncond_ar_labels.expand(bz, -1).to(device=device)
else:
cond_bos = self.bos_emb(torch.zeros(bz, device=device, dtype=torch.long)).unsqueeze(1)
uncond_bos = cond_bos
ar_labels = class_label
uncond_labels = self.uncond_ar_labels.expand(bz, -1).to(device=device)
quant_input = torch.cat([cond_bos, uncond_bos], dim=0) if use_cfg else cond_bos
ar_labels = torch.cat([ar_labels, uncond_labels], dim=0) if use_cfg else ar_labels
quant_output = []
past_kvs = None
for step in range(self.max_length):
# CFG
if use_cfg:
ar_out = self.ar_model(quant_input, ar_labels, cache_kv=cache_kv, past_kvs=past_kvs)
if cache_kv:
hidden_all, past_kvs = ar_out
else:
hidden_all = ar_out
if self.tied_embedding:
hidden_all = F.linear(hidden_all[:, -1:], self.semantic_emb.weight)
logits_all = hidden_all[:, -1]
logits, uncond_logits = logits_all.chunk(2, dim=0)
is_semantic_step = (self.z_len > 0 and step < self.z_len)
if use_segmented:
if is_semantic_step:
seg_schedule = semantic_cfg_schedule or 'constant'
seg_scale = semantic_cfg_scale if semantic_cfg_scale is not None else cfg
seg_power = semantic_cfg_power if semantic_cfg_power is not None else cfg_power
seg_start = semantic_cfg_start
seg_t = step / self.z_len if self.z_len > 0 else 0.0
else:
seg_schedule = visual_cfg_schedule or 'constant'
seg_scale = visual_cfg_scale if visual_cfg_scale is not None else cfg
seg_power = visual_cfg_power if visual_cfg_power is not None else cfg_power
seg_start = visual_cfg_start
visual_start = self.z_len
visual_len = self.max_length - visual_start
seg_t = (step - visual_start) / visual_len if visual_len > 0 else 0.0
if seg_schedule == 'constant':
cfg_scale = seg_scale
elif seg_schedule == 'linear':
cfg_scale = seg_start + (seg_scale - seg_start) * seg_t
elif seg_schedule == 'cosine':
shape = (1 - math.cos(
(seg_t ** seg_power) * math.pi)) * 0.5
cfg_scale = seg_start + (seg_scale - seg_start) * shape
else:
raise ValueError(f"Invalid segmented cfg schedule: {seg_schedule}")
elif cfg_schedule == 'constant':
cfg_scale = cfg
elif cfg_schedule == 'linear':
cfg_scale = 1.0 * (1-step/self.max_length) + cfg * (step/self.max_length)
elif cfg_schedule == 'cosine':
cfg_scale = (1 - math.cos(
((step / self.max_length) ** cfg_power) * math.pi)) * 1/2
cfg_scale = (cfg - 1) * cfg_scale + 1
else:
raise ValueError(f"Invalid cfg_schedule: {cfg_schedule}")
logits = cfg_scale * logits + (1 - cfg_scale) * uncond_logits
else:
ar_out = self.ar_model(quant_input, ar_labels, cache_kv=cache_kv, past_kvs=past_kvs)
if cache_kv:
hidden, past_kvs = ar_out
else:
hidden = ar_out
if self.tied_embedding:
hidden = F.linear(hidden[:, -1:], self.semantic_emb.weight)
logits = hidden[:, -1]
is_semantic_step_temp = (self.z_len > 0 and step < self.z_len)
t = semantic_temperature if (semantic_temperature is not None and is_semantic_step_temp) else temperature
logits = logits / t
if self.logit_mask is not None:
logits = logits + self.logit_mask[step]
# Top-K filtering
if topK is not None and topK > 0.:
top_logits, top_indices = logits.topk(topK, dim=-1)
logits = torch.full_like(logits, float('-inf'))
logits.scatter_(dim=-1, index=top_indices, src=top_logits)
# Top-P (nucleus) filtering
if topP is not None and 0. < topP < 1.:
sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
probs_sum = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
mask = probs_sum > topP
mask[..., 1:] = mask[..., :-1].clone()
mask[..., 0] = False
sorted_logits[mask] = float('-inf')
logits = torch.full_like(logits, float('-inf'))
logits.scatter_(dim=-1, index=sorted_indices, src=sorted_logits)
# Sample the next token
with torch.amp.autocast("cuda", enabled=False):
next_idx = torch.multinomial(F.softmax(logits.float(), dim=-1), 1)
next_idx = next_idx.to(dtype=torch.long)
quant_output.append(next_idx)
# Feed the new token back as the next input embedding
next_emb = self.semantic_emb(next_idx)
if use_cfg:
next_emb = torch.cat([next_emb, next_emb], dim=0) # [2B, 1, D]
if not cache_kv:
quant_input = torch.cat((quant_input, next_emb), dim=1)
else:
quant_input = next_emb
quant_output = torch.cat(quant_output, dim=1)
return quant_output
# deterministic Hard-Gumbel Softmax Quantizer
class L2Normalize(nn.Module):
"""Wrapper for F.normalize to make it a proper nn.Module"""
def __init__(self, dim=-1):
super().__init__()
self.dim = dim
def forward(self, x, element_wise=False):
if not element_wise:
return F.normalize(x, dim=self.dim)
else:
dtype_one = torch.Tensor([1.0]).to(x.device).to(x.dtype)
return torch.where(x>0, dtype_one, -dtype_one)
class PrologueQuantizer(nn.Module):
"""VQ with optional STE variants (``VQ`` / ``LFQ`` / ``ProbQ``) and per-position ``pos_select_mask``."""
def __init__(self, codebook_size, dim, z_dim, cond_dim,
codebook_init='trunc_normal',
z_norm_type='none', # AdaLN, AdaRMSNorm, FixLN, FixRMSNorm, l2, none
temperature=1.0,
frozen_codebook=False, codebook_proj='none',
ste='ProbQ',
length=256,
vq_beta=0.25, use_quant_loss=False, use_entropy_loss=False, monitor_metrics=True,
**params):
super().__init__()
self.monitor_metrics = monitor_metrics
self.codebook_size = codebook_size # codebook size (K)
self.dim = dim # embedding dimension (D)
self.z_dim = z_dim
self.length = int(length)
self.proj_in = Linear(dim, z_dim)
elementwise_affine = True
cond_based_affine = False
self.cond_proj = None
if 'Ada' in z_norm_type:
self.cond_proj = nn.Sequential(Linear(1001, cond_dim), AdaRMSNorm(cond_dim, cond_based_affine=False), nn.GELU(approximate='tanh'))
elementwise_affine = True
cond_based_affine = True
if 'Fix' in z_norm_type:
cond_based_affine = False
elementwise_affine = False
self.z_norm_type = z_norm_type
self.z_norm = nn.Identity()
if 'l2' in z_norm_type:
self.z_norm = L2Normalize(dim=-1)
elif 'RMSNorm' in z_norm_type:
self.z_norm = AdaRMSNorm(z_dim, cond_dim, elementwise_affine=elementwise_affine, cond_based_affine=cond_based_affine)
elif 'LN' in z_norm_type:
self.z_norm = AdaLN(z_dim, cond_dim, elementwise_affine=elementwise_affine, cond_based_affine=cond_based_affine)
if 'LFQ' in z_norm_type:
self.z_norm = L2Normalize(dim=-1, element_wise=True)
self.codebook_init = codebook_init
if self.codebook_init == 'LFQ':
self.register_buffer("indices_map", 2**torch.arange(z_dim).view(1, 1, -1)) # B L D
self.register_buffer("embedding", self.indices_to_emb(torch.arange(codebook_size).view(-1, 1, 1)).view(-1, z_dim)) # N D
else:
self.embedding = nn.Embedding(codebook_size, z_dim)
if codebook_init == 'trunc_normal':
torch.nn.init.trunc_normal_(self.embedding.weight, std=0.02)
elif codebook_init == 'uniform':
self.embedding.weight.data.uniform_(-np.sqrt(1 / codebook_size), np.sqrt(1 / codebook_size))
elif codebook_init == 'simvq':
self.embedding.weight.data.normal_(mean=0, std=self.z_dim**-0.5)
if frozen_codebook:
self.embedding.requires_grad_(False)
self.codebook_proj = nn.Identity()
if codebook_proj == 'linear':
self.codebook_proj = Linear(z_dim, z_dim)
elif codebook_proj == 'norm_linear':
self.codebook_proj = nn.Sequential(AdaRMSNorm(z_dim, elementwise_affine=True, cond_based_affine=False), Linear(z_dim, z_dim))
elif codebook_proj == 'linear_norm':
self.codebook_proj = nn.Sequential(Linear(z_dim, z_dim), AdaRMSNorm(z_dim, elementwise_affine=True, cond_based_affine=False))
elif codebook_proj == 'mlp':
self.codebook_proj = nn.Sequential(Linear(z_dim, dim), nn.GELU(approximate='tanh'), Linear(dim, z_dim))
elif codebook_proj == 'mlp_with_norm':
self.codebook_proj = nn.Sequential(Linear(z_dim, dim), AdaRMSNorm(dim, elementwise_affine=True, cond_based_affine=False), nn.GELU(approximate='tanh'), Linear(dim, z_dim))
elif codebook_proj == 'l2':
self.codebook_proj = L2Normalize(dim=-1)
self.ste = ste
self.vq_beta = vq_beta
self.use_quant_loss = use_quant_loss
self.use_entropy_loss = use_entropy_loss
self.temperature = temperature
print0(f"z_norm_type: {z_norm_type}")
print0(f"codebook_proj: {codebook_proj}")
print0(f"codebook_init: {codebook_init}")
print0(f"frozen_codebook: {frozen_codebook}")
print0(f"temperature: {temperature}")
print0(f"ste: {ste}")
# All-zero pos_select_mask; AR sampling composes a global mask (see utils.build_ar_logit_mask).
mask_bool = torch.ones(self.length, codebook_size, dtype=torch.bool)
self.register_buffer('pos_select_mask', torch.where(mask_bool, 0., float('-inf')))
def indices_to_emb(self, indices):
return ((indices.int() & self.indices_map) != 0).float() * 2. - 1.
@property
def codebook(self):
if self.codebook_init == 'LFQ':
return self.embedding
else:
return self.embedding.weight
def _get_pos_mask(self, L):
mask = self.pos_select_mask
return mask[:L] if L < mask.shape[0] else mask
@autocast("cuda", enabled=False)
def forward(self, x, labels=None, training=False, log_usage=False, indices=None, return_continuous=False):
B, L, D = x.shape
x = x.float()
labels = self.cond_proj(labels.float()) if self.cond_proj is not None and labels is not None else labels
z = self.proj_in(x)
z_normed = self.z_norm(z, labels) if 'Ada' in self.z_norm_type else self.z_norm(z)
codebook = self.codebook
codebook_normed = self.codebook_proj(codebook)
logits = torch.einsum('bld,nd->bln', z_normed, codebook_normed)
prob = F.softmax(logits / self.temperature, dim=-1)
pos_mask = self._get_pos_mask(L)
indices = torch.argmax(prob + pos_mask, dim=-1) if indices is None else indices
one_hot_ng = F.one_hot(indices, self.codebook_size).view(B, L, -1).to(z.device).to(z.dtype)
one_hot = prob + (one_hot_ng - prob).detach()
if self.ste == 'VQ':
quant = codebook_normed[indices]
quant_ste = z_normed + (quant - z_normed).detach()
elif self.ste == 'LFQ':
quant = codebook[indices]
quant_ste = z + (quant - z).detach()
elif self.ste == 'ProbQ':
quant_ste = torch.einsum('bln,nd->bld', one_hot, codebook_normed)
quant_loss = torch.tensor(0., device=z.device, dtype=z.dtype)
if self.monitor_metrics or (training and self.use_quant_loss):
if self.ste == 'VQ':
quant_loss = self.vq_beta * (quant.detach() - z_normed).pow(2).mean() + (quant - z_normed.detach()).pow(2).mean()
elif self.ste == 'LFQ':
quant_loss = self.vq_beta * torch.mean((quant.detach() - z)**2) + torch.mean((quant - z.detach())**2)
elif self.ste == 'ProbQ':
quant_ng = torch.einsum('bln,nd->bld', one_hot_ng, codebook_normed)
quant_loss = torch.mean((quant_ste - z_normed)**2) + self.vq_beta * torch.mean((quant_ng.detach() - z_normed)**2) + torch.mean((quant_ng - z_normed.detach())**2)
quant_loss = quant_loss.detach() if not (training and self.use_quant_loss) else quant_loss
sample_entropy = torch.tensor(0., device=z.device, dtype=z.dtype)
batch_entropy = torch.tensor(0., device=z.device, dtype=z.dtype)
entropy_loss = torch.tensor(0., device=z.device, dtype=z.dtype)
if self.monitor_metrics or (training and self.use_entropy_loss):
masked_logits = logits + pos_mask
sample_entropy, batch_entropy, entropy_loss = compute_entropy_loss(logits=masked_logits.reshape(-1, self.codebook_size))
sample_entropy = sample_entropy.detach() if not (training and self.use_entropy_loss) else sample_entropy
batch_entropy = batch_entropy.detach() if not (training and self.use_entropy_loss) else batch_entropy
entropy_loss = entropy_loss.detach() if not (training and self.use_entropy_loss) else entropy_loss
l2norm_code = torch.tensor(0., device=z.device, dtype=z.dtype)
l2norm_z = torch.tensor(0., device=z.device, dtype=z.dtype)
if self.monitor_metrics:
l2norm_code = torch.norm(codebook_normed, p=2, dim=-1).mean().detach()
l2norm_z = torch.norm(z_normed, p=2, dim=-1).mean().detach()
loss_tuple = (quant_loss, entropy_loss, sample_entropy, batch_entropy, l2norm_z, l2norm_code)
if return_continuous:
return quant_ste, indices, one_hot, loss_tuple, z_normed
return quant_ste, indices, one_hot, loss_tuple
@autocast("cuda", enabled=False)
def encode(self, x: torch.Tensor, labels=None):
B, L, D = x.shape
x = x.float()
labels = self.cond_proj(labels.float()) if self.cond_proj is not None and labels is not None else labels
z = self.proj_in(x)
z_normed = self.z_norm(z, labels) if 'Ada' in self.z_norm_type else self.z_norm(z)
codebook = self.codebook
codebook_normed = self.codebook_proj(codebook)
logits = torch.einsum('bld,nd->bln', z_normed, codebook_normed)
indices = torch.argmax(logits + self._get_pos_mask(L), dim=-1)
return indices
@autocast("cuda", enabled=False)
def get_codes_w_indices(self, indices, labels=None, **params):
if self.codebook_init == 'LFQ':
codes = self.indices_to_emb(indices)
else:
codes = self.embedding(indices)
codes = self.codebook_proj(codes)
return codes
def compute_entropy_loss(
logits,
temperature=0.01,
sample_minimization_weight=1.0,
batch_maximization_weight=1.0,
eps=1e-5,
):
"""Entropy loss on logits (affinities over the last dim); from MAGVIT (Yu et al., 2024)."""
with torch.amp.autocast("cuda", enabled=False):
probs = F.softmax(logits.float() / temperature, -1)
log_probs = F.log_softmax(logits.float() / temperature + eps, -1)
probs = probs.to(logits.dtype)
log_probs = log_probs.to(logits.dtype)
avg_probs = reduce(probs, "... D -> D", "mean")
avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps))
sample_entropy = -torch.sum(torch.nan_to_num(probs * log_probs, nan=0.0), -1)
sample_entropy = torch.mean(sample_entropy)
loss = (sample_minimization_weight * sample_entropy) - (
batch_maximization_weight * avg_entropy
)
return sample_entropy, avg_entropy, loss