"""IRIS Refinement Core: Weight-shared denoising backbone.""" import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint from .pde_ssm import PDESSMBlock from .blocks import (MultiQueryCrossAttention, MultiQuerySelfAttention, UIBFFN, TimestepEmbedding, IterationEmbedding) class IRISBlock(nn.Module): def __init__(self, dim, num_heads=4, spatial_size=4, use_attn=False, ffn_expansion=2): super().__init__() self.cross_attn = MultiQueryCrossAttention(dim, num_heads=num_heads) if use_attn: self.spatial_mixer = MultiQuerySelfAttention(dim, num_heads=num_heads) else: self.spatial_mixer = PDESSMBlock(dim, spatial_size=spatial_size) self.ffn = UIBFFN(dim, expansion=ffn_expansion, spatial_size=spatial_size) self.use_attn = use_attn def forward(self, x, context, H, W): x = self.cross_attn(x, context) x = self.spatial_mixer(x, H, W) x = self.ffn(x, H, W) return x class RefinementCore(nn.Module): """6 blocks (5 PDE-SSM + 1 Self-Attention), applied R times with same weights.""" def __init__(self, dim=512, num_blocks=6, num_heads=4, spatial_size=4, max_iterations=8, ffn_expansion=2, gradient_checkpointing=True): super().__init__() self.dim = dim self.num_blocks = num_blocks self.max_iterations = max_iterations self.gradient_checkpointing = gradient_checkpointing self.blocks = nn.ModuleList() for i in range(num_blocks): use_attn = (i == num_blocks - 1) self.blocks.append(IRISBlock(dim=dim, num_heads=num_heads, spatial_size=spatial_size, use_attn=use_attn, ffn_expansion=ffn_expansion)) self.timestep_embed = TimestepEmbedding(dim) self.iter_embed = IterationEmbedding(dim, max_iterations=max_iterations) self.final_norm = nn.LayerNorm(dim) def _single_iteration(self, x, context, t_emb, iter_emb, H, W): cond = (t_emb + iter_emb).unsqueeze(1) x = x + cond for block in self.blocks: x = block(x, context, H, W) return x def forward(self, x, context, t, H, W, num_iterations=4): B = x.shape[0] t_emb = self.timestep_embed(t) for r in range(num_iterations): iter_emb = self.iter_embed(r, B, x.device) if self.gradient_checkpointing and self.training: x = checkpoint(self._single_iteration, x, context, t_emb, iter_emb, H, W, use_reentrant=False) else: x = self._single_iteration(x, context, t_emb, iter_emb, H, W) return self.final_norm(x)