"""
3D UNet architecture for video diffusion with text conditioning
Enhanced with Transformer (DiT-style) blocks for better temporal modeling
Based on:
- Diffusion Transformers (DiT) - Peebles & Xie 2023
- Video diffusion models with temporal attention
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from einops import rearrange, repeat
def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int) -> torch.Tensor:
"""
Create sinusoidal timestep embeddings.
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
return emb
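# Illustrative usage of the embedding above (assumed shapes; not part of the model itself):
#   t = torch.randint(0, 1000, (4,))                 # (B,) integer diffusion timesteps
#   emb = get_timestep_embedding(t, 128)             # (B, 128): concatenated [sin | cos] halves
# UNet3D below calls this with embedding_dim = model_channels and then projects the result
# through its `time_embed` MLP to (B, model_channels * 4).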
def get_3d_sincos_pos_embed(embed_dim: int, grid_size: Tuple[int, int, int]) -> torch.Tensor:
"""
Generate 3D sinusoidal positional embeddings for video (T, H, W).
"""
t, h, w = grid_size
grid_t = torch.arange(t, dtype=torch.float32)
grid_h = torch.arange(h, dtype=torch.float32)
grid_w = torch.arange(w, dtype=torch.float32)
grid = torch.meshgrid(grid_t, grid_h, grid_w, indexing='ij')
grid = torch.stack(grid, dim=0) # (3, T, H, W)
grid = grid.reshape(3, -1).T # (T*H*W, 3)
    # Split the embedding dim across the three axes (each split is assumed to be even so the
    # sin/cos halves fill it exactly)
dim_t = embed_dim // 3
dim_h = embed_dim // 3
dim_w = embed_dim - dim_t - dim_h
def get_1d_sincos(positions, dim):
omega = torch.arange(dim // 2, dtype=torch.float32)
omega = 1.0 / (10000 ** (omega / (dim // 2)))
out = positions[:, None] * omega[None, :]
return torch.cat([torch.sin(out), torch.cos(out)], dim=1)
emb_t = get_1d_sincos(grid[:, 0], dim_t)
emb_h = get_1d_sincos(grid[:, 1], dim_h)
emb_w = get_1d_sincos(grid[:, 2], dim_w)
return torch.cat([emb_t, emb_h, emb_w], dim=1) # (T*H*W, embed_dim)
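# Example sketch for the 3D positional embedding (this helper is defined here but not referenced
# elsewhere in this file; values below are illustrative):
#   pos = get_3d_sincos_pos_embed(384, (16, 8, 8))   # (16*8*8, 384) = (1024, 384)
# 384 splits as 128 (time) + 128 (height) + 128 (width) sin/cos features per location.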
class GroupNorm32(nn.GroupNorm):
"""GroupNorm with float32 computation for stability"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
return super().forward(x.float()).type(x.dtype)
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization (more efficient than LayerNorm)"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
return x / rms * self.weight
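# RMSNorm computes y = x / sqrt(mean(x**2, dim=-1) + eps) * weight; unlike LayerNorm it skips the
# mean subtraction and bias term, which is where the efficiency noted in the docstring comes from.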
class AdaLayerNorm(nn.Module):
"""Adaptive Layer Normalization conditioned on timestep (DiT-style)"""
def __init__(self, dim: int, time_embed_dim: int):
super().__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=False)
self.proj = nn.Linear(time_embed_dim, dim * 2)
    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        # t_emb: (B, time_embed_dim)
        scale, shift = self.proj(t_emb).chunk(2, dim=-1)
        if x.dim() == 3:  # (B, N, C)
            return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        if x.dim() == 5:  # (B, C, T, H, W): LayerNorm expects the channel dim last
            x_norm = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
            return x_norm * (1 + scale[:, :, None, None, None]) + shift[:, :, None, None, None]
        return self.norm(x) * (1 + scale) + shift
class AdaLayerNormZero(nn.Module):
"""Adaptive Layer Normalization with zero-init (DiT-style)"""
def __init__(self, dim: int, time_embed_dim: int):
super().__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=False)
self.proj = nn.Linear(time_embed_dim, dim * 6) # scale, shift, gate for both attn and ff
nn.init.zeros_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> Tuple[torch.Tensor, ...]:
params = self.proj(t_emb)
return self.norm(x), params.chunk(6, dim=-1)
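# Sketch of how a caller would typically consume AdaLayerNormZero (hypothetical caller, shown for
# clarity; the six chunks follow the "scale, shift, gate for attn / ff" order noted above,
# broadcasting details omitted):
#   x_norm, (s_a, b_a, g_a, s_f, b_f, g_f) = ada_ln_zero(x, t_emb)
#   x = x + g_a * attn(x_norm * (1 + s_a) + b_a)
#   x = x + g_f * ff(ff_norm(x) * (1 + s_f) + b_f)
# Because the projection is zero-initialised, every gate starts at 0 and each residual branch
# begins as an identity mapping. DiTBlock below applies the same pattern with nine chunks
# (it also modulates a cross-attention branch).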
class Upsample3D(nn.Module):
"""3D Upsampling with convolution"""
def __init__(self, channels: int):
super().__init__()
self.conv = nn.Conv3d(channels, channels, 3, padding=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.interpolate(x, scale_factor=(1, 2, 2), mode='nearest')
return self.conv(x)
class Downsample3D(nn.Module):
"""3D Downsampling with convolution"""
def __init__(self, channels: int):
super().__init__()
self.conv = nn.Conv3d(channels, channels, 3, stride=(1, 2, 2), padding=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class ResBlock3D(nn.Module):
"""3D Residual block with time and context conditioning"""
def __init__(
self,
in_channels: int,
out_channels: int,
time_emb_dim: int,
dropout: float = 0.1,
):
super().__init__()
self.in_layers = nn.Sequential(
GroupNorm32(32, in_channels),
nn.SiLU(),
nn.Conv3d(in_channels, out_channels, 3, padding=1),
)
self.time_emb_proj = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, out_channels),
)
self.out_layers = nn.Sequential(
GroupNorm32(32, out_channels),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_channels, out_channels, 3, padding=1),
)
if in_channels != out_channels:
self.skip_connection = nn.Conv3d(in_channels, out_channels, 1)
else:
self.skip_connection = nn.Identity()
def forward(
self,
x: torch.Tensor,
time_emb: torch.Tensor,
) -> torch.Tensor:
h = self.in_layers(x)
# Add time embedding
time_emb = self.time_emb_proj(time_emb)
h = h + time_emb[:, :, None, None, None]
h = self.out_layers(h)
return self.skip_connection(x) + h
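# Shape walkthrough for ResBlock3D (illustrative): with x of shape (B, C_in, T, H, W), `in_layers`
# produces (B, C_out, T, H, W); the projected timestep embedding (B, C_out) is broadcast via
# time_emb[:, :, None, None, None], so every frame and pixel of a sample receives the same
# per-channel shift before `out_layers` and the 1x1x1 (or identity) skip connection.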
class SpatialAttention(nn.Module):
"""Self-attention over spatial dimensions"""
def __init__(self, channels: int, num_heads: int = 8):
super().__init__()
self.num_heads = num_heads
self.head_dim = channels // num_heads
self.norm = GroupNorm32(32, channels)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.proj = nn.Conv1d(channels, channels, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, c, t, h, w = x.shape
# Reshape to (B*T, C, H*W)
x_flat = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h * w)
# Normalize
x_norm = self.norm(x_flat.view(b * t, c, h, w)).view(b * t, c, h * w)
# QKV projection
qkv = self.qkv(x_norm)
q, k, v = qkv.chunk(3, dim=1)
# Reshape for multi-head attention
q = q.view(b * t, self.num_heads, self.head_dim, h * w).permute(0, 1, 3, 2)
k = k.view(b * t, self.num_heads, self.head_dim, h * w).permute(0, 1, 3, 2)
v = v.view(b * t, self.num_heads, self.head_dim, h * w).permute(0, 1, 3, 2)
# Attention
scale = self.head_dim ** -0.5
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
attn = F.softmax(attn, dim=-1)
out = torch.matmul(attn, v)
out = out.permute(0, 1, 3, 2).reshape(b * t, c, h * w)
out = self.proj(out)
out = out.view(b, t, c, h, w).permute(0, 2, 1, 3, 4)
return x + out
class CrossAttention(nn.Module):
"""Cross-attention for text conditioning"""
def __init__(
self,
query_dim: int,
context_dim: int,
num_heads: int = 8,
head_dim: int = 64,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
inner_dim = head_dim * num_heads
self.norm = GroupNorm32(32, query_dim)
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(0.1),
)
def forward(
self,
x: torch.Tensor,
context: torch.Tensor,
) -> torch.Tensor:
b, c, t, h, w = x.shape
        # Normalize, then flatten to (B, T*H*W, C)
        x_norm = self.norm(x.view(b, c, -1)).permute(0, 2, 1)
# QKV
q = self.to_q(x_norm)
k = self.to_k(context)
v = self.to_v(context)
# Reshape for multi-head
q = q.view(b, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
k = k.view(b, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
v = v.view(b, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
# Attention
scale = self.head_dim ** -0.5
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
attn = F.softmax(attn, dim=-1)
out = torch.matmul(attn, v)
out = out.permute(0, 2, 1, 3).reshape(b, t * h * w, -1)
out = self.to_out(out)
out = out.view(b, t, h, w, c).permute(0, 4, 1, 2, 3)
return x + out
class TemporalAttention(nn.Module):
"""Self-attention over temporal dimension"""
def __init__(self, channels: int, num_heads: int = 8):
super().__init__()
self.num_heads = num_heads
self.head_dim = channels // num_heads
self.norm = GroupNorm32(32, channels)
self.qkv = nn.Linear(channels, channels * 3)
self.proj = nn.Linear(channels, channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, c, t, h, w = x.shape
        # Normalize, then reshape to (B*H*W, T, C)
        x_norm = self.norm(x.view(b, c, -1)).view(b, c, t, h, w)
        x_norm = x_norm.permute(0, 3, 4, 2, 1).reshape(b * h * w, t, c)
# QKV
qkv = self.qkv(x_norm)
q, k, v = qkv.chunk(3, dim=-1)
# Reshape for multi-head
q = q.view(b * h * w, t, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
k = k.view(b * h * w, t, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
v = v.view(b * h * w, t, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
# Attention
scale = self.head_dim ** -0.5
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
attn = F.softmax(attn, dim=-1)
out = torch.matmul(attn, v)
out = out.permute(0, 2, 1, 3).reshape(b * h * w, t, c)
out = self.proj(out)
out = out.view(b, h, w, t, c).permute(0, 4, 3, 1, 2)
return x + out
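# Taken together, SpatialAttention (per frame, over H*W tokens), TemporalAttention (per spatial
# location, over T tokens) and CrossAttention (over the text tokens) factorise full 3D attention:
# roughly O(T*(H*W)**2 + H*W*T**2 + T*H*W*S) work per sample and layer instead of O((T*H*W)**2),
# where S is the text sequence length.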
# ============================================================================
# Transformer Components (DiT-style)
# ============================================================================
class MultiHeadAttention(nn.Module):
"""
    Multi-head scaled dot-product attention.
Supports both self-attention and cross-attention.
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
is_cross_attention: bool = False,
context_dim: Optional[int] = None,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.is_cross_attention = is_cross_attention
if is_cross_attention:
self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
self.to_kv = nn.Linear(context_dim or dim, dim * 2, bias=qkv_bias)
else:
self.to_qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, N, C = x.shape
if self.is_cross_attention and context is not None:
q = self.to_q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
kv = self.to_kv(context).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
else:
qkv = self.to_qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
# Scaled dot-product attention
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)
out = self.proj(out)
out = self.proj_drop(out)
return out
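# Illustrative shapes for MultiHeadAttention (assumed values): for self-attention, x of shape
# (B, N, C) yields q, k, v of shape (B, num_heads, N, head_dim); for cross-attention, k and v come
# from `context` of shape (B, M, context_dim), so the attention map has shape (B, num_heads, N, M).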
class FeedForward(nn.Module):
"""Feed-forward network with GELU activation"""
def __init__(
self,
dim: int,
hidden_dim: Optional[int] = None,
dropout: float = 0.0,
):
super().__init__()
hidden_dim = hidden_dim or dim * 4
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class DiTBlock(nn.Module):
"""
Diffusion Transformer Block (DiT-style).
Uses adaptive layer norm for timestep conditioning.
"""
def __init__(
self,
dim: int,
num_heads: int,
time_embed_dim: int,
mlp_ratio: float = 4.0,
dropout: float = 0.0,
context_dim: Optional[int] = None,
):
super().__init__()
# Self-attention with adaptive norm
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
self.attn = MultiHeadAttention(dim, num_heads, attn_drop=dropout, proj_drop=dropout)
# Cross-attention for text conditioning
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
self.cross_attn = MultiHeadAttention(
dim, num_heads,
attn_drop=dropout,
proj_drop=dropout,
is_cross_attention=True,
context_dim=context_dim,
)
# Feed-forward with adaptive norm
self.norm3 = nn.LayerNorm(dim, elementwise_affine=False)
self.ff = FeedForward(dim, int(dim * mlp_ratio), dropout)
# Adaptive parameters (DiT-style)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(time_embed_dim, dim * 9), # 3 params each for 3 blocks
)
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)
def forward(
self,
x: torch.Tensor,
t_emb: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Get adaptive parameters
params = self.adaLN_modulation(t_emb)
(
scale1, shift1, gate1,
scale2, shift2, gate2,
scale3, shift3, gate3,
) = params.unsqueeze(1).chunk(9, dim=-1)
# Self-attention
x_norm = self.norm1(x) * (1 + scale1) + shift1
x = x + gate1 * self.attn(x_norm)
# Cross-attention
if context is not None:
x_norm = self.norm2(x) * (1 + scale2) + shift2
x = x + gate2 * self.cross_attn(x_norm, context)
# Feed-forward
x_norm = self.norm3(x) * (1 + scale3) + shift3
x = x + gate3 * self.ff(x_norm)
return x
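# Minimal usage sketch for DiTBlock (illustrative sizes; comments only, nothing here runs on import):
#   block = DiTBlock(dim=256, num_heads=8, time_embed_dim=1024, context_dim=512)
#   tokens = torch.randn(2, 64, 256)              # (B, N, C) flattened patch tokens
#   t_emb = torch.randn(2, 1024)                  # (B, time_embed_dim)
#   text = torch.randn(2, 77, 512)                # (B, seq_len, context_dim)
#   tokens = block(tokens, t_emb, context=text)   # -> (2, 64, 256)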
class TemporalTransformerBlock(nn.Module):
"""
Transformer block specifically for temporal attention.
Processes video frames attending to other frames.
"""
def __init__(
self,
dim: int,
num_heads: int,
time_embed_dim: int,
dropout: float = 0.0,
):
super().__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=False)
self.attn = MultiHeadAttention(dim, num_heads, attn_drop=dropout, proj_drop=dropout)
# Adaptive parameters
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(time_embed_dim, dim * 3),
)
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)
def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (B, T, C) temporal sequence
t_emb: (B, time_embed_dim) timestep embedding
"""
params = self.adaLN_modulation(t_emb)
scale, shift, gate = params.unsqueeze(1).chunk(3, dim=-1)
x_norm = self.norm(x) * (1 + scale) + shift
x = x + gate * self.attn(x_norm)
return x
class SpatioTemporalTransformer(nn.Module):
"""
Combined spatial and temporal transformer for video understanding.
First applies spatial attention within each frame, then temporal attention across frames.
"""
def __init__(
self,
dim: int,
num_heads: int,
time_embed_dim: int,
context_dim: int,
depth: int = 2,
dropout: float = 0.0,
):
super().__init__()
self.spatial_blocks = nn.ModuleList([
DiTBlock(dim, num_heads, time_embed_dim, dropout=dropout, context_dim=context_dim)
for _ in range(depth)
])
self.temporal_blocks = nn.ModuleList([
TemporalTransformerBlock(dim, num_heads, time_embed_dim, dropout)
for _ in range(depth)
])
def forward(
self,
x: torch.Tensor, # (B, C, T, H, W)
t_emb: torch.Tensor, # (B, time_embed_dim)
context: torch.Tensor, # (B, seq_len, context_dim)
) -> torch.Tensor:
B, C, T, H, W = x.shape
# Spatial attention: process each frame
# Reshape to (B*T, H*W, C)
x_spatial = rearrange(x, 'b c t h w -> (b t) (h w) c')
t_emb_spatial = repeat(t_emb, 'b d -> (b t) d', t=T)
context_spatial = repeat(context, 'b n d -> (b t) n d', t=T)
for block in self.spatial_blocks:
x_spatial = block(x_spatial, t_emb_spatial, context_spatial)
# Reshape back: (B, T, H*W, C)
x_spatial = rearrange(x_spatial, '(b t) n c -> b t n c', b=B, t=T)
# Temporal attention: process each spatial location
# Reshape to (B*H*W, T, C)
x_temporal = rearrange(x_spatial, 'b t n c -> (b n) t c', n=H*W)
t_emb_temporal = repeat(t_emb, 'b d -> (b n) d', n=H*W)
for block in self.temporal_blocks:
x_temporal = block(x_temporal, t_emb_temporal)
# Reshape back to (B, C, T, H, W)
x_out = rearrange(x_temporal, '(b h w) t c -> b c t h w', b=B, h=H, w=W)
return x_out
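# Data layout used by SpatioTemporalTransformer (illustrative): a (B, C, T, H, W) feature map is
# flattened to (B*T, H*W, C) so each DiTBlock attends within a single frame (with text
# cross-attention), regrouped to (B*H*W, T, C) so each TemporalTransformerBlock attends across
# frames at a fixed spatial location, and finally reshaped back to (B, C, T, H, W).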
class TransformerBlock3D(nn.Module):
"""
Enhanced Transformer block with spatial, temporal, and cross attention.
Uses DiT-style adaptive layer norm for better timestep conditioning.
"""
def __init__(
self,
channels: int,
context_dim: int,
time_embed_dim: int,
num_heads: int = 8,
transformer_depth: int = 1,
use_spatio_temporal: bool = True,
):
super().__init__()
self.use_spatio_temporal = use_spatio_temporal
if use_spatio_temporal:
# Use the new SpatioTemporalTransformer
self.transformer = SpatioTemporalTransformer(
dim=channels,
num_heads=num_heads,
time_embed_dim=time_embed_dim,
context_dim=context_dim,
depth=transformer_depth,
)
else:
# Fallback to simpler attention
self.spatial_attn = SpatialAttention(channels, num_heads)
self.temporal_attn = TemporalAttention(channels, num_heads)
self.cross_attn = CrossAttention(
query_dim=channels,
context_dim=context_dim,
num_heads=num_heads,
)
# Feed-forward (used in both cases)
self.ff = nn.Sequential(
GroupNorm32(32, channels),
nn.Conv3d(channels, channels * 4, 1),
nn.GELU(),
nn.Conv3d(channels * 4, channels, 1),
)
def forward(
self,
x: torch.Tensor,
context: torch.Tensor,
t_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
        if self.use_spatio_temporal:
            assert t_emb is not None, "t_emb is required when use_spatio_temporal=True"
            x = self.transformer(x, t_emb, context)
        else:
            x = self.spatial_attn(x)
            x = self.temporal_attn(x)
            x = self.cross_attn(x, context)
x = x + self.ff(x)
return x
class UNet3D(nn.Module):
"""
3D UNet for video diffusion with text conditioning.
Enhanced with DiT-style transformer blocks for better temporal modeling.
"""
def __init__(
self,
in_channels: int = 3,
model_channels: int = 128,
out_channels: int = 3,
num_res_blocks: int = 2,
attention_resolutions: Tuple[int, ...] = (8, 16),
channel_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_heads: int = 8,
context_dim: int = 512,
dropout: float = 0.1,
use_transformer: bool = True, # Use enhanced transformer blocks
transformer_depth: int = 1, # Depth of transformer blocks
use_gradient_checkpointing: bool = False, # Enable gradient checkpointing for memory
):
super().__init__()
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.channel_mult = channel_mult
self.num_heads = num_heads
self.use_transformer = use_transformer
self.use_gradient_checkpointing = use_gradient_checkpointing
time_embed_dim = model_channels * 4
self.time_embed_dim = time_embed_dim
# Time embedding
self.time_embed = nn.Sequential(
nn.Linear(model_channels, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)
# Input convolution
self.input_blocks = nn.ModuleList([
nn.Conv3d(in_channels, model_channels, 3, padding=1)
])
# Downsampling
ch = model_channels
input_block_chans = [ch]
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock3D(ch, mult * model_channels, time_embed_dim, dropout)
]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(
TransformerBlock3D(
channels=ch,
context_dim=context_dim,
time_embed_dim=time_embed_dim,
num_heads=num_heads,
transformer_depth=transformer_depth,
use_spatio_temporal=use_transformer,
)
)
self.input_blocks.append(nn.ModuleList(layers))
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
self.input_blocks.append(nn.ModuleList([Downsample3D(ch)]))
input_block_chans.append(ch)
ds *= 2
# Middle
self.middle_block = nn.ModuleList([
ResBlock3D(ch, ch, time_embed_dim, dropout),
TransformerBlock3D(
channels=ch,
context_dim=context_dim,
time_embed_dim=time_embed_dim,
num_heads=num_heads,
transformer_depth=transformer_depth,
use_spatio_temporal=use_transformer,
),
ResBlock3D(ch, ch, time_embed_dim, dropout),
])
# Upsampling
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock3D(ch + ich, mult * model_channels, time_embed_dim, dropout)
]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(
TransformerBlock3D(
channels=ch,
context_dim=context_dim,
time_embed_dim=time_embed_dim,
num_heads=num_heads,
transformer_depth=transformer_depth,
use_spatio_temporal=use_transformer,
)
)
if level and i == num_res_blocks:
layers.append(Upsample3D(ch))
ds //= 2
self.output_blocks.append(nn.ModuleList(layers))
# Output
self.out = nn.Sequential(
GroupNorm32(32, ch),
nn.SiLU(),
nn.Conv3d(ch, out_channels, 3, padding=1),
)
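    # Channel/resolution schedule for the default arguments, shown for illustration: with
    # model_channels=128 and channel_mult=(1, 2, 4, 8) the encoder runs 128 -> 256 -> 512 -> 1024
    # channels while H and W are halved at each of the three Downsample3D stages (T is never
    # downsampled); with attention_resolutions=(8, 16) transformer blocks are only built at the
    # 8x-downsampled level, plus the middle block, which always contains one.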
def _checkpoint_forward(self, layer, h, t_emb, context=None):
"""Helper for gradient checkpointing"""
if isinstance(layer, ResBlock3D):
return layer(h, t_emb)
elif isinstance(layer, TransformerBlock3D):
return layer(h, context, t_emb)
elif isinstance(layer, (Downsample3D, Upsample3D)):
return layer(h)
return h
def forward(
self,
x: torch.Tensor, # (B, C, T, H, W)
timesteps: torch.Tensor, # (B,)
context: torch.Tensor, # (B, seq_len, context_dim)
) -> torch.Tensor:
"""
Forward pass
Args:
x: Noisy video tensor (B, C, T, H, W)
timesteps: Diffusion timesteps (B,)
context: Text embeddings (B, seq_len, context_dim)
Returns:
Predicted noise (B, C, T, H, W)
"""
# Time embedding
t_emb = get_timestep_embedding(timesteps, self.model_channels)
t_emb = self.time_embed(t_emb)
# Downsampling path
hs = []
h = x
for module in self.input_blocks:
if isinstance(module, nn.Conv3d):
h = module(h)
elif isinstance(module, nn.ModuleList):
for layer in module:
if self.use_gradient_checkpointing and self.training:
h = checkpoint(self._checkpoint_forward, layer, h, t_emb, context, use_reentrant=False)
else:
h = self._checkpoint_forward(layer, h, t_emb, context)
hs.append(h)
# Middle
for layer in self.middle_block:
if self.use_gradient_checkpointing and self.training:
h = checkpoint(self._checkpoint_forward, layer, h, t_emb, context, use_reentrant=False)
else:
h = self._checkpoint_forward(layer, h, t_emb, context)
# Upsampling path
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
for layer in module:
if self.use_gradient_checkpointing and self.training:
h = checkpoint(self._checkpoint_forward, layer, h, t_emb, context, use_reentrant=False)
else:
h = self._checkpoint_forward(layer, h, t_emb, context)
return self.out(h)
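# Note on gradient checkpointing in UNet3D.forward above: when `use_gradient_checkpointing` is set
# and the model is in training mode, every ResBlock3D / TransformerBlock3D / Downsample3D /
# Upsample3D call is routed through torch.utils.checkpoint.checkpoint (use_reentrant=False),
# trading extra forward recomputation during backward for lower activation memory.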
def create_unet(config) -> UNet3D:
"""Create UNet model from config"""
return UNet3D(
in_channels=config.in_channels,
model_channels=config.model_channels,
out_channels=config.in_channels,
num_res_blocks=config.num_res_blocks,
attention_resolutions=config.attention_resolutions,
channel_mult=config.channel_mult,
num_heads=config.num_heads,
context_dim=config.context_dim,
use_transformer=getattr(config, 'use_transformer', True),
transformer_depth=getattr(config, 'transformer_depth', 1),
use_gradient_checkpointing=getattr(config, 'use_gradient_checkpointing', False),
)
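# Hypothetical usage sketch for create_unet: the real config object lives elsewhere in this repo,
# but any object exposing these attributes works (optional flags fall back via getattr defaults):
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(in_channels=3, model_channels=128, num_res_blocks=2,
#                         attention_resolutions=(8, 16), channel_mult=(1, 2, 4, 8),
#                         num_heads=8, context_dim=512)
#   model = create_unet(cfg)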
if __name__ == "__main__":
# Test the enhanced model with transformer blocks
print("Testing UNet3D with DiT-style Transformer blocks...")
model = UNet3D(
in_channels=3,
model_channels=64,
channel_mult=(1, 2, 4),
attention_resolutions=(8, 16),
num_heads=4,
context_dim=256,
use_transformer=True,
transformer_depth=1,
)
# Test input
batch_size = 2
x = torch.randn(batch_size, 3, 16, 64, 64) # (B, C, T, H, W)
t = torch.randint(0, 1000, (batch_size,))
context = torch.randn(batch_size, 77, 256) # (B, seq_len, context_dim)
# Forward pass
out = model(x, t, context)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
# Test backward pass
loss = out.sum()
loss.backward()
print("Backward pass successful!")
# Test without transformer (legacy mode)
print("\nTesting UNet3D without transformer (legacy mode)...")
model_legacy = UNet3D(
in_channels=3,
model_channels=64,
channel_mult=(1, 2, 4),
attention_resolutions=(8, 16),
num_heads=4,
context_dim=256,
use_transformer=False,
)
out_legacy = model_legacy(x, t, context)
print(f"Legacy output shape: {out_legacy.shape}")
print(f"Legacy parameters: {sum(p.numel() for p in model_legacy.parameters()):,}")