# flow-matching/src/stage2/velocity_net.py
import math
from typing import Optional
import torch
import torch.nn as nn


class SinusoidalPosEmb(nn.Module):
"""Sinusoidal positional embedding for timestep inputs."""
def __init__(self, dim: int):
super().__init__()
self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
if t.ndim == 0:
t = t.unsqueeze(0)
if not torch.is_floating_point(t):
t = t.float()
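        # Flow-matching time is assumed to lie in [0, 1]; scaling by 1000 maps it
        # onto the usual diffusion-style timestep range the embedding expects.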
t = t * 1000.0
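        # Standard transformer sinusoidal embedding: half the channels get sin,
        # half get cos, with frequencies spaced geometrically from 1 down to 1/10000.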
half_dim = self.dim // 2
emb_scale = math.log(10000) / max(half_dim - 1, 1)
emb = torch.exp(
torch.arange(half_dim, device=t.device, dtype=t.dtype) * -emb_scale
)
emb = t.unsqueeze(1) * emb.unsqueeze(0)
return torch.cat([emb.sin(), emb.cos()], dim=-1)


class MultiTokenFusion(nn.Module):
"""Project each modality to a shared hidden space and fuse across modalities."""
def __init__(
self,
modality_dims: list[int],
hidden_dim: int = 256,
dropout: float = 0.1,
modality_dropout: float = 0.0,
):
super().__init__()
self.modality_dims = modality_dims
self.n_modalities = len(modality_dims)
self.hidden_dim = hidden_dim
self.modality_dropout = modality_dropout
self.projectors = nn.ModuleList(
[
nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
)
for dim in modality_dims
]
)
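        # Learned per-modality embedding added after projection so the fused
        # representation can still tell the modalities apart.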
self.modality_emb = nn.Parameter(torch.randn(self.n_modalities, hidden_dim) * 0.02)
self.output_proj = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
)

    def forward(self, modality_features: list[torch.Tensor]) -> torch.Tensor:
if len(modality_features) != self.n_modalities:
raise ValueError(
f"Expected {self.n_modalities} modalities, got {len(modality_features)}."
)
projected = []
for i, (feat, proj) in enumerate(zip(modality_features, self.projectors)):
h = proj(feat)
h = h + self.modality_emb[i]
projected.append(h)
if self.training and self.modality_dropout > 0:
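            # Sample an independent keep/drop decision per (batch, step, modality);
            # True means the modality is kept at that position.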
keep_mask = (
torch.rand(
projected[0].shape[0],
projected[0].shape[1],
self.n_modalities,
device=projected[0].device,
)
> self.modality_dropout
)
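            # Guarantee at least one modality survives: positions where everything
            # was dropped fall back to keeping modality 0.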
all_dropped = keep_mask.sum(dim=2, keepdim=True) == 0
            keep_mask[:, :, 0:1] = keep_mask[:, :, 0:1] | all_dropped
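            # Inverted-dropout rescaling keeps the expected magnitude of the
            # fused features constant between training and inference.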
scale = 1.0 / max(1.0 - self.modality_dropout, 1e-6)
for i in range(self.n_modalities):
projected[i] = projected[i] * keep_mask[:, :, i : i + 1] * scale
x = torch.stack(projected, dim=0).mean(dim=0)
return self.output_proj(x)


class SimpleFiLMBlock(nn.Module):
"""Residual FiLM block with feed-forward and context cross-attention."""
def __init__(
self,
dim: int,
time_dim: int,
context_dim: int,
n_heads: int = 8,
dropout: float = 0.1,
):
super().__init__()
self.film = nn.Linear(time_dim, dim * 2)
self.norm1 = nn.LayerNorm(dim)
self.ffn = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim * 4, dim),
nn.Dropout(dropout),
)
self.norm_q = nn.LayerNorm(dim)
self.norm_kv = nn.LayerNorm(context_dim)
self.cross_attn = nn.MultiheadAttention(
dim,
n_heads,
dropout=dropout,
batch_first=True,
kdim=context_dim,
vdim=context_dim,
)

    def forward(
self,
x: torch.Tensor,
t_emb: torch.Tensor,
context: torch.Tensor,
) -> torch.Tensor:
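        # FiLM conditioning: the timestep embedding produces a per-channel
        # scale and shift that modulate the normalized state.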
scale_shift = self.film(t_emb)
scale, shift = scale_shift.chunk(2, dim=-1)
h = self.norm1(x) * (1 + scale) + shift
x = x + self.ffn(h)
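        # Single-query cross-attention: the (B, dim) state attends over the
        # (B, T, context_dim) context sequence, then is added back residually.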
q = self.norm_q(x).unsqueeze(1)
kv = self.norm_kv(context)
attn_out, _ = self.cross_attn(q, kv, kv, need_weights=False)
x = x + attn_out.squeeze(1)
return x


class VelocityNet(nn.Module):
"""DiT-style velocity estimator with late-fusion context conditioning."""
def __init__(
self,
output_dim: int,
hidden_dim: int = 256,
modality_dims: Optional[list[int]] = None,
n_blocks: int = 4,
n_heads: int = 8,
dropout: float = 0.1,
modality_dropout: float = 0.0,
max_seq_len: int = 2048,
temporal_attn_layers: int = 2,
):
super().__init__()
self.output_dim = output_dim
self.hidden_dim = hidden_dim
self.modality_dims = modality_dims or [output_dim]
self.max_seq_len = max_seq_len
self.fusion_block = MultiTokenFusion(
modality_dims=self.modality_dims,
hidden_dim=hidden_dim,
dropout=dropout,
modality_dropout=modality_dropout,
)
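        # Learned absolute positional embedding for the context sequence; this
        # is what caps the context length at max_seq_len in encode_context.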
self.context_pos_emb = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim) * 0.02)
if temporal_attn_layers > 0:
self.temporal_attn = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=n_heads,
dim_feedforward=hidden_dim * 4,
dropout=dropout,
activation="gelu",
batch_first=True,
norm_first=True,
),
num_layers=temporal_attn_layers,
)
else:
self.temporal_attn = nn.Identity()
self.temporal_norm = nn.LayerNorm(hidden_dim)
self.input_proj = nn.Sequential(
nn.Linear(output_dim, hidden_dim),
nn.GELU(),
)
self.time_emb = SinusoidalPosEmb(hidden_dim)
self.time_mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim),
)
self.blocks = nn.ModuleList(
[
SimpleFiLMBlock(
dim=hidden_dim,
time_dim=hidden_dim,
context_dim=hidden_dim,
n_heads=n_heads,
dropout=dropout,
)
for _ in range(n_blocks)
]
)
self.final_norm = nn.LayerNorm(hidden_dim)
self.output_layer = nn.Linear(hidden_dim, output_dim)
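        # Zero-init the output head so the initial velocity prediction is zero,
        # a common stabilizer for flow-matching and diffusion models.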
nn.init.constant_(self.output_layer.weight, 0)
nn.init.constant_(self.output_layer.bias, 0)

    def encode_context(self, cond: torch.Tensor) -> torch.Tensor:
"""Encode context tensor from (B, T, total_dim) to (B, T, hidden_dim)."""
if cond.ndim != 3:
raise ValueError(f"Expected cond with shape (B, T, D), got {tuple(cond.shape)}")
B, T, D = cond.shape
if T > self.max_seq_len:
raise ValueError(
f"Sequence length {T} exceeds max_seq_len={self.max_seq_len}. "
"Increase max_seq_len in stage2.velocity_net config."
)
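        # Split the concatenated conditioning features back into per-modality
        # chunks along the channel axis, in the order given by modality_dims.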
splits = []
offset = 0
for dim in self.modality_dims:
splits.append(cond[:, :, offset : offset + dim])
offset += dim
if offset != D:
raise ValueError(
f"Context dim mismatch: expected sum(modality_dims)={offset}, got {D}."
)
context = self.fusion_block(splits)
context = context + self.context_pos_emb[:, :T, :]
context = self.temporal_attn(context)
context = self.temporal_norm(context)
return context

    def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
cond: Optional[torch.Tensor] = None,
pre_encoded_context: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
if t.ndim == 0:
t = t.expand(x.shape[0])
if pre_encoded_context is not None:
context_encoded = pre_encoded_context
elif cond is not None:
context_encoded = self.encode_context(cond)
else:
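            # Unconditional path: use a single all-zero context token so the
            # cross-attention blocks still have something to attend to.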
context_encoded = torch.zeros(
x.shape[0],
1,
self.hidden_dim,
device=x.device,
dtype=x.dtype,
)
t_emb = self.time_mlp(self.time_emb(t))
h = self.input_proj(x)
for block in self.blocks:
h = block(h, t_emb, context_encoded)
h = self.final_norm(h)
return self.output_layer(h)
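

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the training pipeline).
# The output_dim, modality widths, batch size, and sequence length below are
# hypothetical; substitute the values from your stage2 config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    net = VelocityNet(output_dim=16, hidden_dim=256, modality_dims=[64, 32])
    net.eval()  # disable dropout so both forward paths below agree

    x = torch.randn(8, 16)              # current sample, shape (B, output_dim)
    t = torch.rand(8)                   # flow-matching time in [0, 1]
    cond = torch.randn(8, 10, 64 + 32)  # context, shape (B, T, sum(modality_dims))

    v = net(x, t, cond=cond)
    assert v.shape == (8, 16)

    # The context can be encoded once and reused across ODE-solver steps:
    ctx = net.encode_context(cond)
    v2 = net(x, t, pre_encoded_context=ctx)
    assert torch.allclose(v, v2)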