from typing import *
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..modules.utils import convert_module_to, manual_cast, str_to_dtype
from ..modules.transformer import AbsolutePositionEmbedder
from ..modules import sparse as sp
from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
from .sparse_structure_flow import TimestepEmbedder
from .sparse_elastic_mixin import SparseTransformerElasticMixin


class SLatFlowModel(nn.Module):
    """
    Structured Latent Flow Model for 3D generation.

    Supports three conditioning modes, selected by ``image_attn_mode``:
    - "cross": Standard cross-attention with image features. ``cond`` is a
      tensor or a list of tensors.
    - "proj": View-aligned projection attention with camera-aware features.
      ``cond`` is ``{'global': ..., 'proj': ...}`` or a
      ``(global_cond, proj_cond)`` tuple.
    - "gated_proj": Projection variant whose ``cond`` is a dict with
      ``'global'``, ``'proj_semantic'`` and ``'proj_color'`` entries.

    The attention variants themselves are implemented in
    ``ModulatedSparseTransformerCrossBlock``; this class only prepares and
    casts the conditioning before handing it to every block.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        cond_channels: int,
        out_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        pe_mode: Literal["ape", "rope"] = "ape",
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        dtype: str = 'float32',
        use_checkpoint: bool = False,
        share_mod: bool = False,
        initialization: str = 'vanilla',
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        image_attn_mode: Literal["cross", "proj", "gated_proj"] = "cross",
        proj_in_channels: Optional[int] = None,
        vae_in_channels: Optional[int] = None,
        **kwargs
    ):
        """
        Build the model.

        Args:
            resolution: Stored on the model; not otherwise used in this class.
            in_channels: Feature channels fed to the input layer (i.e. after
                ``concat_cond`` has been concatenated in ``forward``).
            model_channels: Hidden width of the transformer.
            cond_channels: Conditioning width, forwarded to every block.
            out_channels: Feature channels produced by the output layer.
            num_blocks: Number of transformer blocks.
            num_heads: Attention heads; if falsy, derived as
                ``model_channels // num_head_channels``.
            num_head_channels: Channels per head, used only when
                ``num_heads`` is not given.
            mlp_ratio: MLP expansion ratio, forwarded to every block.
            pe_mode: "ape" adds an absolute position embedding to the input
                tokens; "rope" enables rotary embeddings inside the blocks.
            rope_freq: Rotary frequency range, forwarded to every block.
            dtype: Name of the compute dtype, resolved with ``str_to_dtype``
                and applied to the blocks via ``convert_to``.
            use_checkpoint: Forwarded to every block.
            share_mod: If True, a single adaLN modulation MLP (output width
                ``6 * model_channels``) is shared by all blocks; otherwise each
                block owns its own ``adaLN_modulation``.
            initialization: 'vanilla' or 'scaled' (see ``initialize_weights``).
            qk_rms_norm: Forwarded to every block.
            qk_rms_norm_cross: Forwarded to every block.
            image_attn_mode: Conditioning mode (see class docstring).
            proj_in_channels: Forwarded to every block.
            vae_in_channels: Forwarded to every block.
            **kwargs: Accepted and ignored (presumably so extra config keys
                can be passed through unchanged).
        """
        super().__init__()
        self.resolution = resolution
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.pe_mode = pe_mode
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.initialization = initialization
        self.qk_rms_norm = qk_rms_norm
        self.qk_rms_norm_cross = qk_rms_norm_cross
        self.image_attn_mode = image_attn_mode
        self.proj_in_channels = proj_in_channels
        self.vae_in_channels = vae_in_channels
        self.dtype = str_to_dtype(dtype)

        # NOTE: submodule registration order below determines the traversal
        # order of ``self.apply`` (and hence RNG consumption during init) and
        # the ``state_dict`` key order, so keep it stable.
        self.t_embedder = TimestepEmbedder(model_channels)
        if share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels, 6 * model_channels, bias=True)
            )

        if pe_mode == "ape":
            self.pos_embedder = AbsolutePositionEmbedder(model_channels)

        self.input_layer = sp.SparseLinear(in_channels, model_channels)

        self.blocks = nn.ModuleList([
            ModulatedSparseTransformerCrossBlock(
                model_channels,
                cond_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode='full',
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                rope_freq=rope_freq,
                share_mod=self.share_mod,
                qk_rms_norm=self.qk_rms_norm,
                qk_rms_norm_cross=self.qk_rms_norm_cross,
                image_attn_mode=image_attn_mode,
                proj_in_channels=proj_in_channels,
                vae_in_channels=vae_in_channels,
            )
            for _ in range(num_blocks)
        ])

        self.out_layer = sp.SparseLinear(model_channels, out_channels)

        self.initialize_weights()
        self.convert_to(self.dtype)

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to(self, dtype: torch.dtype) -> None:
        """
        Convert the torso of the model to the specified dtype.

        Only ``self.blocks`` are converted; the embedders and the input/output
        layers keep their dtype, and ``forward`` casts activations between the
        two with ``manual_cast``.
        """
        self.dtype = dtype
        self.blocks.apply(partial(convert_module_to, dtype=dtype))

    def initialize_weights(self) -> None:
        """
        Initialize parameters according to ``self.initialization``.

        - 'vanilla': Xavier-uniform weights and zero biases for every
          ``nn.Linear``.
        - 'scaled': normal weights with std ``sqrt(2 / (5 * model_channels))``
          for every ``nn.Linear``, a smaller std
          ``1 / sqrt(5 * num_blocks * model_channels)`` on each block's
          attention output projections and ``mlp.mlp[2]`` (ffn2), and an input
          layer with std ``1 / sqrt(in_channels)`` and zero bias.

        Both modes then finish with ``_init_timestep_modulation_and_output``.
        Any other value leaves the modules' default initialization untouched.
        """
        if self.initialization == 'vanilla':
            # Initialize transformer layers:
            def _basic_init(module):
                if isinstance(module, nn.Linear):
                    torch.nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)
            self.apply(_basic_init)

            self._init_timestep_modulation_and_output()

        elif self.initialization == 'scaled':
            # Initialize transformer layers:
            def _basic_init(module):
                if isinstance(module, nn.Linear):
                    torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels)))
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)
            self.apply(_basic_init)

            # Scaled init for to_out and ffn2: shrink residual-branch outputs
            # with depth.
            def _scaled_init(module):
                if isinstance(module, nn.Linear):
                    torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels))
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)
            for block in self.blocks:
                block.self_attn.to_out.apply(_scaled_init)
                # Projection modes wrap the cross-attention in another module.
                if self.image_attn_mode in ("proj", "gated_proj"):
                    block.cross_attn.cross_attn_block.to_out.apply(_scaled_init)
                else:
                    block.cross_attn.to_out.apply(_scaled_init)
                block.mlp.mlp[2].apply(_scaled_init)

            # Initialize input layer to make the initial representation have variance 1
            nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels))
            nn.init.zeros_(self.input_layer.bias)

            self._init_timestep_modulation_and_output()

    def _init_timestep_modulation_and_output(self) -> None:
        """Init the timestep MLP and zero the adaLN and output layers (shared by all init modes)."""
        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks (shared or per-block):
        if self.share_mod:
            modulations = [self.adaLN_modulation]
        else:
            modulations = [block.adaLN_modulation for block in self.blocks]
        for modulation in modulations:
            nn.init.constant_(modulation[-1].weight, 0)
            nn.init.constant_(modulation[-1].bias, 0)

        # Zero-out output layers so the untrained model outputs zeros:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def _cast_cond_tensor(self, c):
        """Pack a list of tensors into a ``VarLenTensor`` if needed, then cast to ``self.dtype``."""
        if isinstance(c, list):
            c = sp.VarLenTensor.from_tensor_list(c)
        return manual_cast(c, self.dtype)

    def _prepare_cond(self, cond):
        """Normalize ``cond`` into the structure the blocks expect for ``self.image_attn_mode``."""
        if self.image_attn_mode == 'proj':
            if isinstance(cond, dict):
                global_cond = cond['global']
                proj_cond = cond['proj']
            else:
                global_cond, proj_cond = cond
            global_cond = self._cast_cond_tensor(global_cond)
            # The projected condition is cast as-is (no list packing).
            proj_cond = manual_cast(proj_cond, self.dtype)
            return (global_cond, proj_cond)

        if self.image_attn_mode == 'gated_proj':
            global_cond = self._cast_cond_tensor(cond['global'])
            proj_semantic = manual_cast(cond['proj_semantic'], self.dtype)
            proj_color = manual_cast(cond['proj_color'], self.dtype)
            return {'global': global_cond, 'proj_semantic': proj_semantic, 'proj_color': proj_color}

        # "cross" (and any other mode): a tensor or a list of tensors.
        return self._cast_cond_tensor(cond)

    def forward(
        self,
        x: sp.SparseTensor,
        t: torch.Tensor,
        cond: Union[torch.Tensor, List[torch.Tensor], Dict[str, Union[torch.Tensor, sp.SparseTensor]], Tuple],
        concat_cond: Optional[sp.SparseTensor] = None,
        **kwargs
    ) -> sp.SparseTensor:
        """
        Forward pass.

        Args:
            x: SparseTensor input
            t: Timestep tensor [B]
            cond: Conditioning, depending on ``image_attn_mode``:
                - "cross": a tensor or a list of tensors (lists are packed
                  into a ``VarLenTensor``).
                - "proj": dict ``{'global': global_cond, 'proj': proj_cond}``
                  or tuple ``(global_cond, proj_cond)``.
                - "gated_proj": dict with keys ``'global'``,
                  ``'proj_semantic'`` and ``'proj_color'``.
            concat_cond: Optional SparseTensor concatenated to ``x`` along the
                feature dimension before the input layer.
            **kwargs: Accepted and ignored.

        Returns:
            SparseTensor with ``out_channels`` features, in the dtype of the
            (possibly concatenated) input.
        """
        if concat_cond is not None:
            x = sp.sparse_cat([x, concat_cond], dim=-1)

        h = self.input_layer(x)
        h = manual_cast(h, self.dtype)
        t_emb = self.t_embedder(t)
        if self.share_mod:
            t_emb = self.adaLN_modulation(t_emb)
        t_emb = manual_cast(t_emb, self.dtype)
        if self.pe_mode == "ape":
            # Drop the leading coords column (presumably the batch index).
            pe = self.pos_embedder(h.coords[:, 1:])
            h = h + manual_cast(pe, self.dtype)

        cond = self._prepare_cond(cond)

        for block in self.blocks:
            h = block(h, t_emb, cond)

        # Back to the input dtype, then a parameter-free layer norm over the
        # feature dimension before the output projection.
        h = manual_cast(h, x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)
        return h


class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel):
    """
    SLat Flow Model with elastic memory management.
    Used for training with low VRAM.
    """
    pass