# cond_gen/decoders.py
# Uploaded by Leon299 using the upload-large-folder tool (commit 8337fa0, verified).
import torch
from typing import Optional, Tuple
from torch import nn as nn
from transformers.cache_utils import Cache
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
class AdaLN(nn.Module):
    """
    DiT-style adaptive LayerNorm modulation head.

    Projects a conditioning token through SiLU + Linear into six per-position
    modulation tensors: (shift, scale, gate) for the attention branch followed
    by (shift, scale, gate) for the MLP branch.

    With ``zero_init=True`` the projection weight and bias start at exactly
    zero, so every modulation tensor is zero at step 0 and downstream
    ``x * (1 + scale) + shift`` / ``(1 + gate) * out`` reduce to the identity,
    preserving the base model's behavior mathematically.
    """

    def __init__(
        self,
        hidden_size: int,
        cond_dim: int,
        zero_init: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.act = nn.SiLU()
        self.linear = nn.Linear(cond_dim, 6 * hidden_size, bias=True)
        if zero_init:
            # Exact no-op at initialization: all six outputs are zero.
            nn.init.zeros_(self.linear.weight)
            nn.init.zeros_(self.linear.bias)

    def forward(self, cond_token: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            cond_token: [B, T, cond_dim] conditioning features.

        Returns:
            Six tensors, each [B, T, hidden_size], in the order
            (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp).
        """
        projected = self.linear(self.act(cond_token))  # [B, T, 6H]
        return projected.chunk(6, dim=-1)
def apply_adaln(
    x_norm: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    """Affine-modulate a normalized activation: ``x_norm * (1 + scale) + shift``.

    With shift == scale == 0 this is the identity, which is what makes
    zero-initialized AdaLN a no-op at step 0.
    """
    return shift + x_norm * (scale + 1.0)
class Qwen3DecoderLayerAdaLN(Qwen3DecoderLayer):
    """
    Qwen3 decoder layer with DiT-style AdaLN conditioning.

    On positions where ``condition_mask`` is True, the RMSNorm outputs of both
    sub-layers are modulated as ``x * (1 + scale) + shift`` and the sub-layer
    outputs are gated as ``(1 + gate) * out``. Because shift/scale/gate are
    zero-initialized, gate == 0 yields exactly the base layer's output, so the
    pretrained behavior is preserved at step 0. Unmasked positions always take
    the unmodified path via ``torch.where``.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        cond_dim: int,
        zero_init: bool = True,
    ):
        super().__init__(config, layer_idx)
        self.dit_adaln = AdaLN(
            hidden_size=config.hidden_size,
            cond_dim=cond_dim,
            zero_init=zero_init,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        cond_expanded: Optional[torch.Tensor] = None,  # [B, T, cond_dim]
        condition_mask: Optional[torch.BoolTensor] = None,  # [B, T]
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ):
        # The conditioning path stays fully tensor-based; no .item() checks
        # that would force a GPU->CPU sync during autoregressive decoding.
        has_cond = (cond_expanded is not None) and (condition_mask is not None)
        if has_cond:
            (
                shift_msa,
                scale_msa,
                gate_msa,
                shift_mlp,
                scale_mlp,
                gate_mlp,
            ) = self.dit_adaln(cond_expanded)
            on_cond = condition_mask.unsqueeze(-1)  # [B, T, 1]

        # --- self-attention sub-layer ---
        skip = hidden_states
        normed = self.input_layernorm(hidden_states)  # Qwen3 RMSNorm
        if has_cond:
            # AdaLN modulation, only applied on conditioned positions.
            modulated = normed * (1.0 + scale_msa) + shift_msa
            normed = torch.where(on_cond, modulated, normed)
        attn_out, _ = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        if has_cond:
            # gate == 0 reproduces the base output: (1 + 0) * attn_out.
            attn_out = torch.where(on_cond, (1.0 + gate_msa) * attn_out, attn_out)
        hidden_states = skip + attn_out

        # --- MLP sub-layer ---
        skip = hidden_states
        normed = self.post_attention_layernorm(hidden_states)
        if has_cond:
            modulated = normed * (1.0 + scale_mlp) + shift_mlp
            normed = torch.where(on_cond, modulated, normed)
        mlp_out = self.mlp(normed)
        if has_cond:
            # Same gate-zero invariant as the attention branch.
            mlp_out = torch.where(on_cond, (1.0 + gate_mlp) * mlp_out, mlp_out)
        return skip + mlp_out