iris-image-gen / iris /core.py
asdf98's picture
Upload iris/core.py
f307a14 verified
"""IRIS Refinement Core: Weight-shared denoising backbone."""
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from .pde_ssm import PDESSMBlock
from .blocks import (MultiQueryCrossAttention, MultiQuerySelfAttention, UIBFFN, TimestepEmbedding, IterationEmbedding)
class IRISBlock(nn.Module):
    """One refinement block: cross-attention, spatial mixing, then a feed-forward.

    The spatial mixer is picked at construction time: ``MultiQuerySelfAttention``
    when ``use_attn`` is true, otherwise ``PDESSMBlock``. The three sub-modules
    are applied strictly in sequence; this class adds no residual connection or
    normalization of its own. Any residual/norm handling, if present, lives
    inside the sub-modules (defined in ``.blocks`` / ``.pde_ssm`` -- TODO confirm).

    Args:
        dim: Feature width passed to every sub-module.
        num_heads: Head count for the cross-attention and, when ``use_attn`` is
            set, the self-attention mixer (not passed to ``PDESSMBlock``).
        spatial_size: Forwarded to ``PDESSMBlock`` and ``UIBFFN``; its meaning
            is defined by those modules.
        use_attn: Use self-attention instead of PDE-SSM for spatial mixing.
        ffn_expansion: Expansion factor forwarded to ``UIBFFN``.
    """

    def __init__(self, dim, num_heads=4, spatial_size=4, use_attn=False, ffn_expansion=2):
        super().__init__()
        # Plain bool flag; recorded so callers can tell which mixer was built.
        self.use_attn = use_attn
        # Registration order (cross_attn, spatial_mixer, ffn) fixes the
        # parameter order seen by state_dict() and optimizers -- keep it.
        self.cross_attn = MultiQueryCrossAttention(dim, num_heads=num_heads)
        self.spatial_mixer = (
            MultiQuerySelfAttention(dim, num_heads=num_heads)
            if use_attn
            else PDESSMBlock(dim, spatial_size=spatial_size)
        )
        self.ffn = UIBFFN(dim, expansion=ffn_expansion, spatial_size=spatial_size)

    def forward(self, x, context, H, W):
        """Refine ``x`` against ``context``.

        Args:
            x: Token features; presumably (B, H*W, dim) -- TODO confirm.
            context: Conditioning sequence consumed by the cross-attention.
            H: Spatial height, forwarded to the mixer and the FFN.
            W: Spatial width, forwarded to the mixer and the FFN.

        Returns:
            The FFN output (assumed to keep the layout of ``x``).
        """
        attended = self.cross_attn(x, context)
        mixed = self.spatial_mixer(attended, H, W)
        return self.ffn(mixed, H, W)
class RefinementCore(nn.Module):
    """Weight-shared refinement backbone run for several iterations.

    Holds ``num_blocks`` :class:`IRISBlock` layers. Every block uses PDE-SSM
    spatial mixing except the last one, which uses self-attention; with the
    defaults that is 5 PDE-SSM blocks + 1 self-attention block. ``forward``
    runs the whole stack ``num_iterations`` times with the same weights. It
    re-injects the timestep and iteration embeddings at the start of every
    pass, then applies ``final_norm`` (a LayerNorm) once at the end.

    Args:
        dim: Feature width.
        num_blocks: Number of IRISBlocks in the shared stack.
        num_heads: Attention head count forwarded to each block.
        spatial_size: Forwarded to each block.
        max_iterations: Forwarded to ``IterationEmbedding``.
        ffn_expansion: FFN expansion factor forwarded to each block.
        gradient_checkpointing: When true, each iteration is wrapped in
            ``torch.utils.checkpoint`` (non-reentrant) while in training mode.
    """

    def __init__(self, dim=512, num_blocks=6, num_heads=4, spatial_size=4, max_iterations=8, ffn_expansion=2, gradient_checkpointing=True):
        super().__init__()
        self.dim = dim
        self.num_blocks = num_blocks
        self.max_iterations = max_iterations
        self.gradient_checkpointing = gradient_checkpointing
        # Blocks are built in index order, so parameter init and RNG
        # consumption match a plain append loop. Only the final block gets
        # self-attention.
        self.blocks = nn.ModuleList(
            IRISBlock(
                dim=dim,
                num_heads=num_heads,
                spatial_size=spatial_size,
                use_attn=(index == num_blocks - 1),
                ffn_expansion=ffn_expansion,
            )
            for index in range(num_blocks)
        )
        self.timestep_embed = TimestepEmbedding(dim)
        self.iter_embed = IterationEmbedding(dim, max_iterations=max_iterations)
        self.final_norm = nn.LayerNorm(dim)

    def _single_iteration(self, x, context, t_emb, iter_emb, H, W):
        """Run one pass of the shared block stack.

        ``t_emb + iter_emb`` gets a new axis at position 1 and is added to
        ``x``. The result then goes through every block in order. This
        presumably broadcasts (B, dim) embeddings over the token axis of a
        (B, N, dim) ``x`` -- TODO confirm shapes.
        """
        conditioning = t_emb + iter_emb
        hidden = x + conditioning.unsqueeze(1)
        for layer in self.blocks:
            hidden = layer(hidden, context, H, W)
        return hidden

    def forward(self, x, context, t, H, W, num_iterations=4):
        """Refine ``x`` by running the shared stack ``num_iterations`` times.

        Args:
            x: Token features; dim 0 is read as the batch size.
            context: Conditioning sequence forwarded to every block.
            t: Diffusion timestep(s), embedded once by ``timestep_embed``.
            H: Spatial height forwarded to every block.
            W: Spatial width forwarded to every block.
            num_iterations: Number of weight-shared passes. ``0`` skips
                refinement and only applies ``final_norm``.

        Returns:
            ``final_norm`` applied to the output of the last pass.
        """
        # NOTE(review): iteration indices 0..num_iterations-1 are passed to
        # iter_embed. Whether indices >= max_iterations are accepted depends
        # on IterationEmbedding -- confirm before exceeding max_iterations.
        batch_size = x.shape[0]
        timestep_vec = self.timestep_embed(t)
        use_checkpoint = self.gradient_checkpointing and self.training
        for step in range(num_iterations):
            step_vec = self.iter_embed(step, batch_size, x.device)
            step_args = (x, context, timestep_vec, step_vec, H, W)
            if use_checkpoint:
                # Recompute this pass during backward instead of storing
                # its activations.
                x = checkpoint(self._single_iteration, *step_args, use_reentrant=False)
            else:
                x = self._single_iteration(*step_args)
        return self.final_norm(x)