asdf98 committed on
Commit
f307a14
·
verified ·
1 Parent(s): 54a784b

Upload iris/core.py

Browse files
Files changed (1) hide show
  1. iris/core.py +66 -0
iris/core.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """IRIS Refinement Core: Weight-shared denoising backbone."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.checkpoint import checkpoint
6
+
7
+ from .pde_ssm import PDESSMBlock
8
+ from .blocks import (MultiQueryCrossAttention, MultiQuerySelfAttention, UIBFFN, TimestepEmbedding, IterationEmbedding)
9
+
10
+
11
class IRISBlock(nn.Module):
    """Single refinement block: cross-attention conditioning on a context
    sequence, a spatial mixing stage (PDE-SSM or multi-query self-attention),
    and a feed-forward stage, applied in that order.

    Args:
        dim: channel width of the token stream.
        num_heads: head count for the cross-/self-attention modules.
        spatial_size: spatial extent forwarded to the SSM mixer and the FFN.
        use_attn: if True, mix spatially with self-attention; otherwise use
            the PDE-SSM block.
        ffn_expansion: hidden-width multiplier inside the FFN.
    """

    def __init__(self, dim, num_heads=4, spatial_size=4, use_attn=False, ffn_expansion=2):
        super().__init__()
        # Sub-modules are registered in the same order as before:
        # cross_attn -> spatial_mixer -> ffn.
        self.cross_attn = MultiQueryCrossAttention(dim, num_heads=num_heads)
        if use_attn:
            mixer = MultiQuerySelfAttention(dim, num_heads=num_heads)
        else:
            mixer = PDESSMBlock(dim, spatial_size=spatial_size)
        self.spatial_mixer = mixer
        self.ffn = UIBFFN(dim, expansion=ffn_expansion, spatial_size=spatial_size)
        # Kept as an attribute so callers can inspect which mixer variant is in use.
        self.use_attn = use_attn

    def forward(self, x, context, H, W):
        """Run the three stages sequentially; each consumes and returns the tokens.

        Args:
            x: token tensor to refine.
            context: conditioning sequence attended to by cross-attention.
            H, W: spatial dimensions passed to the mixer and FFN.
        """
        out = self.cross_attn(x, context)
        out = self.spatial_mixer(out, H, W)
        return self.ffn(out, H, W)
27
+
28
+
29
class RefinementCore(nn.Module):
    """Weight-shared denoising backbone.

    A single stack of ``num_blocks`` IRIS blocks (every block uses the PDE-SSM
    mixer except the final one, which uses self-attention) is applied
    ``num_iterations`` times per forward pass with the same weights, each
    iteration conditioned on a timestep embedding plus an iteration embedding.
    """

    def __init__(self, dim=512, num_blocks=6, num_heads=4, spatial_size=4,
                 max_iterations=8, ffn_expansion=2, gradient_checkpointing=True):
        super().__init__()
        self.dim = dim
        self.num_blocks = num_blocks
        self.max_iterations = max_iterations
        self.gradient_checkpointing = gradient_checkpointing

        # Only the last block switches to self-attention mixing.
        self.blocks = nn.ModuleList([
            IRISBlock(dim=dim, num_heads=num_heads, spatial_size=spatial_size,
                      use_attn=(idx == num_blocks - 1), ffn_expansion=ffn_expansion)
            for idx in range(num_blocks)
        ])

        self.timestep_embed = TimestepEmbedding(dim)
        self.iter_embed = IterationEmbedding(dim, max_iterations=max_iterations)
        self.final_norm = nn.LayerNorm(dim)

    def _single_iteration(self, x, context, t_emb, iter_emb, H, W):
        """Add the combined conditioning vector, then run the full block stack once."""
        # Broadcast the (B, dim) conditioning over the token axis.
        out = x + (t_emb + iter_emb).unsqueeze(1)
        for blk in self.blocks:
            out = blk(out, context, H, W)
        return out

    def forward(self, x, context, t, H, W, num_iterations=4):
        """Refine ``x`` for ``num_iterations`` passes and apply the final norm.

        Args:
            x: token tensor of shape (B, ..., dim) — batch taken from dim 0.
            context: conditioning sequence for cross-attention.
            t: timestep input consumed by the timestep embedding.
            H, W: spatial dimensions threaded through every block.
            num_iterations: number of weight-shared refinement passes.
        """
        batch = x.shape[0]
        # Timestep embedding is iteration-invariant, so compute it once.
        t_emb = self.timestep_embed(t)

        for step in range(num_iterations):
            iter_emb = self.iter_embed(step, batch, x.device)
            # Checkpoint each full iteration during training to trade compute
            # for activation memory; non-reentrant mode tolerates the int args.
            if self.gradient_checkpointing and self.training:
                x = checkpoint(self._single_iteration, x, context, t_emb,
                               iter_emb, H, W, use_reentrant=False)
            else:
                x = self._single_iteration(x, context, t_emb, iter_emb, H, W)

        return self.final_norm(x)