# Pixal3D / trellis2 / modules / sparse / attention / proj_attention.py
# Uploaded by Yang2001 ("Upload folder using huggingface_hub"), commit 8d595ff (verified).
"""
Sparse View-Aligned Projection Attention Module for TRELLIS2
Sparse versions of ProjectAttention and GatedProjectAttention.
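Supports two modes:
- "proj": Standard projection (DINOv3 only)
- "gated_proj": Fusion of DINOv3 (semantic) + VAE (color) features; the two
  are concatenated and passed through a single zero-initialized linear layer
  (no explicit gate is applied, despite the mode name)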
"""
from typing import *
import torch
import torch.nn as nn
from ..basic import SparseTensor, VarLenTensor
class SparseProjectAttention(nn.Module):
    """
    Cross-attention wrapper that adds a linearly projected side feature.

    The wrapped ``cross_attn_block`` is applied to ``x`` and the global
    context. The projection context is sent through this block's own
    ``nn.Linear`` and summed with the features of the attention output.
    """

    def __init__(self, cross_attn_block: nn.Module, channels: int, proj_in_channels: int):
        """
        Args:
            cross_attn_block: Module invoked as ``cross_attn_block(x, global_context)``;
                its result must expose ``.feats`` and ``.replace(...)``.
            channels: Output width of the projection layer.
            proj_in_channels: Input width of the projection layer.
        """
        super().__init__()
        self.cross_attn_block = cross_attn_block
        self.proj_linear = nn.Linear(
            in_features=proj_in_channels,
            out_features=channels,
            bias=True,
        )

    def forward(
        self,
        x: SparseTensor,
        context: Union[Dict[str, Union[torch.Tensor, VarLenTensor, SparseTensor]],
                       Tuple[Union[torch.Tensor, VarLenTensor], SparseTensor]]
    ) -> SparseTensor:
        """
        Run global cross-attention and add the projected context features.

        Args:
            x: Input passed to the wrapped cross-attention block.
            context: Either a dict with keys ``'global'`` and ``'proj'``, or a
                2-sequence ``(global_context, proj_context)``.

        Returns:
            The cross-attention output with its features replaced by
            ``proj_linear(proj) + attention_feats``.
        """
        # Accept both the keyed and the positional form of the context.
        if isinstance(context, dict):
            global_context, proj_context = context['global'], context['proj']
        else:
            global_context, proj_context = context

        attended = self.cross_attn_block(x, global_context)

        # A SparseTensor carries its data in ``.feats``; anything else is fed
        # to the linear layer as-is.
        # NOTE(review): the sum below presumably requires the projected
        # features to be row-aligned with ``attended.feats`` — confirm with caller.
        raw_proj = proj_context.feats if isinstance(proj_context, SparseTensor) else proj_context
        projected = self.proj_linear(raw_proj)

        return attended.replace(projected + attended.feats)
class SparseGatedProjectAttention(nn.Module):
    """
    Cross-attention wrapper that fuses two projected feature streams.

    The semantic (DINOv3) and color (VAE) features are concatenated along the
    last dimension and mapped to ``channels`` by one linear layer. That layer's
    weight and bias start at zero, so at construction the fused term is zero
    and the output equals the plain cross-attention output.

    Expected context entries (dict form):
    - 'global': context handed to the wrapped cross-attention block
    - 'proj_semantic': DINOv3 projected features (SparseTensor or Tensor)
    - 'proj_color': VAE projected features (SparseTensor or Tensor)

    A 3-sequence ``(global, proj_semantic, proj_color)`` is accepted as well.
    """

    def __init__(
        self,
        cross_attn_block: nn.Module,
        channels: int,
        dino_in_channels: int,
        vae_in_channels: int,
    ):
        """
        Args:
            cross_attn_block: Module invoked as ``cross_attn_block(x, global_context)``;
                its result must expose ``.feats`` and ``.replace(...)``.
            channels: Output width of the fusion layer.
            dino_in_channels: Feature width of the semantic stream.
            vae_in_channels: Feature width of the color stream.
        """
        super().__init__()
        self.cross_attn_block = cross_attn_block

        fused_in_channels = dino_in_channels + vae_in_channels
        self.proj_linear = nn.Linear(fused_in_channels, channels, bias=True)

        # Start from an all-zero layer so only the global cross-attention
        # contributes until the fusion weights are trained.
        for param in (self.proj_linear.weight, self.proj_linear.bias):
            nn.init.zeros_(param)

    def _get_feats(self, t):
        """Return ``t.feats`` for a SparseTensor, otherwise ``t`` unchanged."""
        if isinstance(t, SparseTensor):
            return t.feats
        return t

    def forward(
        self,
        x: SparseTensor,
        context: Union[Dict[str, Union[torch.Tensor, VarLenTensor, SparseTensor]], Tuple],
    ) -> SparseTensor:
        """
        Run global cross-attention and add the fused semantic/color projection.

        Args:
            x: Input passed to the wrapped cross-attention block.
            context: Dict with keys ``'global'``, ``'proj_semantic'`` and
                ``'proj_color'``, or the same three values as a sequence.

        Returns:
            The cross-attention output with its features replaced by
            ``proj_linear(cat(semantic, color)) + attention_feats``.
        """
        # Accept both the keyed and the positional form of the context.
        if isinstance(context, dict):
            global_context, proj_semantic, proj_color = (
                context['global'],
                context['proj_semantic'],
                context['proj_color'],
            )
        else:
            global_context, proj_semantic, proj_color = context

        attended = self.cross_attn_block(x, global_context)

        # Join the two streams channel-wise before the shared projection.
        semantic_feats = self._get_feats(proj_semantic)
        color_feats = self._get_feats(proj_color)
        joined = torch.cat([semantic_feats, color_feats], dim=-1)
        fused = self.proj_linear(joined)

        return attended.replace(fused + attended.feats)