Upload iris/core.py
Browse files- iris/core.py +66 -0
iris/core.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""IRIS Refinement Core: Weight-shared denoising backbone."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.checkpoint import checkpoint
|
| 6 |
+
|
| 7 |
+
from .pde_ssm import PDESSMBlock
|
| 8 |
+
from .blocks import (MultiQueryCrossAttention, MultiQuerySelfAttention, UIBFFN, TimestepEmbedding, IterationEmbedding)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class IRISBlock(nn.Module):
|
| 12 |
+
def __init__(self, dim, num_heads=4, spatial_size=4, use_attn=False, ffn_expansion=2):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.cross_attn = MultiQueryCrossAttention(dim, num_heads=num_heads)
|
| 15 |
+
if use_attn:
|
| 16 |
+
self.spatial_mixer = MultiQuerySelfAttention(dim, num_heads=num_heads)
|
| 17 |
+
else:
|
| 18 |
+
self.spatial_mixer = PDESSMBlock(dim, spatial_size=spatial_size)
|
| 19 |
+
self.ffn = UIBFFN(dim, expansion=ffn_expansion, spatial_size=spatial_size)
|
| 20 |
+
self.use_attn = use_attn
|
| 21 |
+
|
| 22 |
+
def forward(self, x, context, H, W):
|
| 23 |
+
x = self.cross_attn(x, context)
|
| 24 |
+
x = self.spatial_mixer(x, H, W)
|
| 25 |
+
x = self.ffn(x, H, W)
|
| 26 |
+
return x
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class RefinementCore(nn.Module):
|
| 30 |
+
"""6 blocks (5 PDE-SSM + 1 Self-Attention), applied R times with same weights."""
|
| 31 |
+
|
| 32 |
+
def __init__(self, dim=512, num_blocks=6, num_heads=4, spatial_size=4, max_iterations=8, ffn_expansion=2, gradient_checkpointing=True):
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.dim = dim
|
| 35 |
+
self.num_blocks = num_blocks
|
| 36 |
+
self.max_iterations = max_iterations
|
| 37 |
+
self.gradient_checkpointing = gradient_checkpointing
|
| 38 |
+
|
| 39 |
+
self.blocks = nn.ModuleList()
|
| 40 |
+
for i in range(num_blocks):
|
| 41 |
+
use_attn = (i == num_blocks - 1)
|
| 42 |
+
self.blocks.append(IRISBlock(dim=dim, num_heads=num_heads, spatial_size=spatial_size, use_attn=use_attn, ffn_expansion=ffn_expansion))
|
| 43 |
+
|
| 44 |
+
self.timestep_embed = TimestepEmbedding(dim)
|
| 45 |
+
self.iter_embed = IterationEmbedding(dim, max_iterations=max_iterations)
|
| 46 |
+
self.final_norm = nn.LayerNorm(dim)
|
| 47 |
+
|
| 48 |
+
def _single_iteration(self, x, context, t_emb, iter_emb, H, W):
|
| 49 |
+
cond = (t_emb + iter_emb).unsqueeze(1)
|
| 50 |
+
x = x + cond
|
| 51 |
+
for block in self.blocks:
|
| 52 |
+
x = block(x, context, H, W)
|
| 53 |
+
return x
|
| 54 |
+
|
| 55 |
+
def forward(self, x, context, t, H, W, num_iterations=4):
|
| 56 |
+
B = x.shape[0]
|
| 57 |
+
t_emb = self.timestep_embed(t)
|
| 58 |
+
|
| 59 |
+
for r in range(num_iterations):
|
| 60 |
+
iter_emb = self.iter_embed(r, B, x.device)
|
| 61 |
+
if self.gradient_checkpointing and self.training:
|
| 62 |
+
x = checkpoint(self._single_iteration, x, context, t_emb, iter_emb, H, W, use_reentrant=False)
|
| 63 |
+
else:
|
| 64 |
+
x = self._single_iteration(x, context, t_emb, iter_emb, H, W)
|
| 65 |
+
|
| 66 |
+
return self.final_norm(x)
|