| """IRIS Refinement Core: Weight-shared denoising backbone.""" |
|
|
| import torch |
| import torch.nn as nn |
| from torch.utils.checkpoint import checkpoint |
|
|
| from .pde_ssm import PDESSMBlock |
| from .blocks import (MultiQueryCrossAttention, MultiQuerySelfAttention, UIBFFN, TimestepEmbedding, IterationEmbedding) |
|
|
|
|
class IRISBlock(nn.Module):
    """One refinement block: cross-attention to the conditioning context,
    spatial mixing (PDE-SSM by default, self-attention when ``use_attn``),
    then a UIB feed-forward network."""

    def __init__(self, dim, num_heads=4, spatial_size=4, use_attn=False, ffn_expansion=2):
        super().__init__()
        self.cross_attn = MultiQueryCrossAttention(dim, num_heads=num_heads)
        # Spatial mixer: global self-attention for the designated block,
        # otherwise the cheaper PDE-SSM operator.
        if use_attn:
            self.spatial_mixer = MultiQuerySelfAttention(dim, num_heads=num_heads)
        else:
            self.spatial_mixer = PDESSMBlock(dim, spatial_size=spatial_size)
        self.ffn = UIBFFN(dim, expansion=ffn_expansion, spatial_size=spatial_size)
        self.use_attn = use_attn

    def forward(self, x, context, H, W):
        # x: (B, H*W, dim) token sequence; context: conditioning tokens.
        x = self.cross_attn(x, context)
        x = self.spatial_mixer(x, H, W)
        x = self.ffn(x, H, W)
        return x


class RefinementCore(nn.Module):
    """Weight-shared refinement stack: ``num_blocks`` IRISBlocks (PDE-SSM
    spatial mixing in all but the last block, which uses self-attention),
    applied ``num_iterations`` times with the same weights. Defaults to
    6 blocks (5 PDE-SSM + 1 self-attention); see the ``__main__`` sketch
    at the bottom of this module for example shapes."""

    def __init__(
        self,
        dim=512,
        num_blocks=6,
        num_heads=4,
        spatial_size=4,
        max_iterations=8,
        ffn_expansion=2,
        gradient_checkpointing=True,
    ):
        super().__init__()
        self.dim = dim
        self.num_blocks = num_blocks
        self.max_iterations = max_iterations
        self.gradient_checkpointing = gradient_checkpointing

        # Shared stack applied at every refinement iteration; only the final
        # block mixes spatially with self-attention.
        self.blocks = nn.ModuleList()
        for i in range(num_blocks):
            use_attn = (i == num_blocks - 1)
            self.blocks.append(IRISBlock(
                dim=dim,
                num_heads=num_heads,
                spatial_size=spatial_size,
                use_attn=use_attn,
                ffn_expansion=ffn_expansion,
            ))

        # Conditioning embeddings: diffusion timestep and refinement iteration.
        self.timestep_embed = TimestepEmbedding(dim)
        self.iter_embed = IterationEmbedding(dim, max_iterations=max_iterations)
        self.final_norm = nn.LayerNorm(dim)

    def _single_iteration(self, x, context, t_emb, iter_emb, H, W):
        # Add timestep + iteration conditioning, broadcast across the token dim.
        cond = (t_emb + iter_emb).unsqueeze(1)
        x = x + cond
        for block in self.blocks:
            x = block(x, context, H, W)
        return x

    def forward(self, x, context, t, H, W, num_iterations=4):
        B = x.shape[0]
        t_emb = self.timestep_embed(t)

        # Weight-tied recurrence: the same block stack is applied
        # num_iterations times, re-conditioned on each iteration index.
        for r in range(num_iterations):
            iter_emb = self.iter_embed(r, B, x.device)
            if self.gradient_checkpointing and self.training:
                # Trade compute for memory: recompute this iteration's
                # activations during the backward pass.
                x = checkpoint(
                    self._single_iteration,
                    x, context, t_emb, iter_emb, H, W,
                    use_reentrant=False,
                )
            else:
                x = self._single_iteration(x, context, t_emb, iter_emb, H, W)

        return self.final_norm(x)
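

# Usage sketch (illustrative only). A minimal smoke test under assumed
# interfaces: x laid out as (B, H*W, dim) tokens, context as (B, N_ctx, dim),
# and t as a (B,) timestep tensor. The actual shapes accepted by
# TimestepEmbedding, IterationEmbedding, and the blocks are defined in
# .blocks / .pde_ssm; adjust if they differ.
if __name__ == "__main__":
    core = RefinementCore(dim=512, num_blocks=6, max_iterations=8)
    core.eval()  # skip gradient checkpointing for the forward-only test
    B, H, W = 2, 16, 16
    x = torch.randn(B, H * W, 512)      # noisy latent tokens (assumed layout)
    context = torch.randn(B, 77, 512)   # conditioning tokens (77 is arbitrary)
    t = torch.randint(0, 1000, (B,))    # diffusion timesteps (assumed range)
    out = core(x, context, t, H, W, num_iterations=4)
    print(out.shape)                    # expected: torch.Size([2, 256, 512])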
|
|