import torch
from typing import Optional, Tuple
from torch import nn as nn
from transformers.cache_utils import Cache
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer


class AdaLN(nn.Module):
    """DiT-style adaptive-LayerNorm conditioning head.

    Projects a conditioning token into six modulation tensors:
    (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp).

    With ``zero_init=True`` the projection is the zero map at step 0, so
    every modulation tensor is exactly 0 and the wrapped layer initially
    reproduces the unconditioned base model (mathematically).
    """

    def __init__(
        self,
        hidden_size: int,
        cond_dim: int,
        zero_init: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.act = nn.SiLU()
        self.linear = nn.Linear(cond_dim, 6 * hidden_size, bias=True)
        if zero_init:
            # Zero weight and bias => all six outputs start at exactly 0.
            nn.init.zeros_(self.linear.weight)
            nn.init.zeros_(self.linear.bias)

    def forward(self, cond_token: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Map ``cond_token`` [B, T, cond_dim] to six [B, T, H] tensors.

        Order: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp.
        """
        projected = self.linear(self.act(cond_token))  # [B, T, 6H]
        return projected.chunk(6, dim=-1)


def apply_adaln(
    x_norm: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    """FiLM-style modulation: ``x_norm * (1 + scale) + shift``."""
    return x_norm * (scale + 1.0) + shift


class Qwen3DecoderLayerAdaLN(Qwen3DecoderLayer):
    """Qwen3 decoder layer with AdaLN injection.

    - Modulates the normalized sub-block input with (shift, scale) on masked
      positions.
    - IMPORTANT: the gate must preserve base behavior at gate == 0:
          out = out_base * (1 + gate)   (on masked positions)
      so that when gate == 0, out == out_base.
    - Modulation is applied only on audio positions (condition_mask == True);
      all other positions follow the vanilla Qwen3 path.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        cond_dim: int,
        zero_init: bool = True,
    ):
        super().__init__(config, layer_idx)
        self.dit_adaln = AdaLN(
            hidden_size=config.hidden_size,
            cond_dim=cond_dim,
            zero_init=zero_init,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        cond_expanded: Optional[torch.Tensor] = None,  # [B, T, cond_dim]
        condition_mask: Optional[torch.BoolTensor] = None,  # [B, T]
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ):
        # Stay fully tensor-based on the conditioning path: no .item() checks,
        # which would force a GPU->CPU sync during autoregressive decoding.
        modulate = (cond_expanded is not None) and (condition_mask is not None)
        if modulate:
            (
                shift_msa,
                scale_msa,
                gate_msa,
                shift_mlp,
                scale_mlp,
                gate_mlp,
            ) = self.dit_adaln(cond_expanded)
            cond_positions = condition_mask.unsqueeze(-1)  # [B, T, 1]

        # ---- Self-Attention branch ----
        residual = hidden_states
        normed = self.input_layernorm(hidden_states)  # RMSNorm in Qwen3
        if modulate:
            # Shift/scale only where condition_mask is True.
            normed = torch.where(
                cond_positions, apply_adaln(normed, shift_msa, scale_msa), normed
            )
        attn_out, _ = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        if modulate:
            # gate == 0 leaves the base attention output untouched.
            attn_out = torch.where(
                cond_positions, (gate_msa + 1.0) * attn_out, attn_out
            )
        hidden_states = residual + attn_out

        # ---- MLP branch ----
        residual = hidden_states
        normed = self.post_attention_layernorm(hidden_states)
        if modulate:
            normed = torch.where(
                cond_positions, apply_adaln(normed, shift_mlp, scale_mlp), normed
            )
        mlp_out = self.mlp(normed)
        if modulate:
            # gate == 0 leaves the base MLP output untouched.
            mlp_out = torch.where(
                cond_positions, (gate_mlp + 1.0) * mlp_out, mlp_out
            )
        return residual + mlp_out