"""IRIS Refinement Core: Weight-shared denoising backbone."""

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .pde_ssm import PDESSMBlock
from .blocks import (
    MultiQueryCrossAttention,
    MultiQuerySelfAttention,
    UIBFFN,
    TimestepEmbedding,
    IterationEmbedding,
)


class IRISBlock(nn.Module):
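    """One refinement unit: cross-attend to the context, mix spatially
    (PDE-SSM by default, multi-query self-attention when ``use_attn``),
    then apply the UIB feed-forward network.
    """
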
    def __init__(self, dim, num_heads=4, spatial_size=4, use_attn=False, ffn_expansion=2):
        super().__init__()
        self.cross_attn = MultiQueryCrossAttention(dim, num_heads=num_heads)
        if use_attn:
            self.spatial_mixer = MultiQuerySelfAttention(dim, num_heads=num_heads)
        else:
            self.spatial_mixer = PDESSMBlock(dim, spatial_size=spatial_size)
        self.ffn = UIBFFN(dim, expansion=ffn_expansion, spatial_size=spatial_size)
        self.use_attn = use_attn

    def forward(self, x, context, H, W):
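        # Shapes assumed from the call pattern, not confirmed in this file:
        # x is (B, H*W, dim) flattened spatial tokens, context is a
        # (B, N_ctx, dim) conditioning sequence.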
        x = self.cross_attn(x, context)
        x = self.spatial_mixer(x, H, W)
        x = self.ffn(x, H, W)
        return x


class RefinementCore(nn.Module):
    """6 blocks (5 PDE-SSM + 1 Self-Attention), applied R times with same weights."""

    def __init__(
        self,
        dim=512,
        num_blocks=6,
        num_heads=4,
        spatial_size=4,
        max_iterations=8,
        ffn_expansion=2,
        gradient_checkpointing=True,
    ):
        super().__init__()
        self.dim = dim
        self.num_blocks = num_blocks
        self.max_iterations = max_iterations
        self.gradient_checkpointing = gradient_checkpointing

        self.blocks = nn.ModuleList()
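        # Only the final block uses multi-query self-attention; all earlier
        # blocks use the PDE-SSM spatial mixer.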
        for i in range(num_blocks):
            use_attn = (i == num_blocks - 1)
            self.blocks.append(
                IRISBlock(
                    dim=dim,
                    num_heads=num_heads,
                    spatial_size=spatial_size,
                    use_attn=use_attn,
                    ffn_expansion=ffn_expansion,
                )
            )

        self.timestep_embed = TimestepEmbedding(dim)
        self.iter_embed = IterationEmbedding(dim, max_iterations=max_iterations)
        self.final_norm = nn.LayerNorm(dim)

    def _single_iteration(self, x, context, t_emb, iter_emb, H, W):
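        # Broadcast the combined timestep + iteration embedding over the
        # token dimension, then run the shared block stack once.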
        cond = (t_emb + iter_emb).unsqueeze(1)
        x = x + cond
        for block in self.blocks:
            x = block(x, context, H, W)
        return x

    def forward(self, x, context, t, H, W, num_iterations=4):
        B = x.shape[0]
        t_emb = self.timestep_embed(t)

        for r in range(num_iterations):
            iter_emb = self.iter_embed(r, B, x.device)
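            # Checkpointing drops this iteration's intermediate activations
            # and recomputes them in backward, trading compute for memory
            # across the unrolled weight-shared iterations.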
            if self.gradient_checkpointing and self.training:
                x = checkpoint(self._single_iteration, x, context, t_emb, iter_emb, H, W, use_reentrant=False)
            else:
                x = self._single_iteration(x, context, t_emb, iter_emb, H, W)

        return self.final_norm(x)
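

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Because this file
# uses relative imports, run it as a package module, e.g.
# ``python -m <package>.refinement_core`` (package/module names assumed).
# Input shapes below are assumptions inferred from the forward signature:
# x as flattened latent tokens (B, H*W, dim), context as a conditioning
# sequence (B, N_ctx, dim), and t as per-sample timesteps (B,); the timestep
# dtype/range and the context length are likewise assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, H, W, dim = 2, 4, 4, 512
    core = RefinementCore(dim=dim, gradient_checkpointing=False)
    x = torch.randn(B, H * W, dim)       # latent tokens to refine
    context = torch.randn(B, 77, dim)    # conditioning tokens (length assumed)
    t = torch.randint(0, 1000, (B,))     # diffusion timesteps (range assumed)
    out = core(x, context, t, H, W, num_iterations=2)
    # Assuming the blocks preserve token shape: torch.Size([2, 16, 512])
    print(out.shape)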