asdf98 committed
Commit 54a784b · verified · Parent: c5e4c7b

Upload iris/blocks.py

Files changed (1)
iris/blocks.py  +164 -0
iris/blocks.py ADDED
@@ -0,0 +1,164 @@
"""
Core building blocks for IRIS: attention, FFN, cross-attention, embeddings.

Design principles:
- MQA (Multi-Query Attention) everywhere: a single K/V head shared across all query heads
- UIB-FFN (Universal Inverted Bottleneck): depthwise-separable conv FFN, expansion=2
- QK-RMSNorm on queries and keys for training stability (from SANA-Sprint)
- 2D (axial) RoPE for spatial position encoding
- Timestep conditioning by addition (not AdaLN), which saves parameters (from HTH)
"""

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

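# Shape conventions used throughout (they follow from the reshapes below):
#   x:       (B, N, D) token sequence, with N = H * W spatial positions
#   context: (B, S, D) conditioning tokens (cross-attention only)
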
class RMSNorm(nn.Module):
    """Root-mean-square normalization (no mean subtraction, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Compute in fp32 for numerical stability, then cast back to the input dtype.
        rms = torch.sqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() / rms * self.weight.float()).to(x.dtype)


class RotaryEmbedding2D(nn.Module):
    """Axial 2D RoPE: half of the rotated channels encode the row index,
    the other half the column index. `max_size` is currently unused."""

    def __init__(self, dim: int, max_size: int = 64):
        super().__init__()
        self.dim = dim
        # dim // 4 frequencies per axis; the rotation then covers dim // 2
        # channel pairs, leaving any remainder channels unrotated.
        half_dim = dim // 4
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half_dim, dtype=torch.float32) / half_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _build_cache(self, H, W, device, dtype):
        # Note: rebuilt on every forward call; cheap for small H, W.
        h_pos = torch.arange(H, device=device, dtype=torch.float32)
        w_pos = torch.arange(W, device=device, dtype=torch.float32)
        inv = self.inv_freq.to(device)
        h_freqs = torch.outer(h_pos, inv)[:, None, :].expand(H, W, -1)
        w_freqs = torch.outer(w_pos, inv)[None, :, :].expand(H, W, -1)
        freqs = torch.cat([h_freqs, w_freqs], dim=-1).reshape(H * W, -1)
        return freqs.cos().to(dtype), freqs.sin().to(dtype)

    def forward(self, x, H, W):
        N = H * W
        cos_c, sin_c = self._build_cache(H, W, x.device, x.dtype)
        if x.dim() == 4:  # (B, heads, N, head_dim)
            cos_c = cos_c[None, None, :N, :]
            sin_c = sin_c[None, None, :N, :]
        else:  # (B, N, head_dim)
            cos_c = cos_c[None, :N, :]
            sin_c = sin_c[None, :N, :]
        # Rotate-half convention: (x1, x2) are the two halves of the rotated span.
        d = cos_c.shape[-1]
        x1, x2, xr = x[..., :d], x[..., d:2*d], x[..., 2*d:]
        return torch.cat([x1*cos_c - x2*sin_c, x1*sin_c + x2*cos_c, xr], dim=-1)


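# Usage sketch for RotaryEmbedding2D (illustrative sizes, not from the IRIS code):
#   rope = RotaryEmbedding2D(dim=16)     # per-head dim
#   q = torch.randn(2, 4, 64, 16)        # (B, heads, H*W, head_dim)
#   q = rope(q, H=8, W=8)                # same shape, positions now encoded
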
class MultiQueryCrossAttention(nn.Module):
    """Pre-norm cross-attention with multi-query K/V (one shared K/V head)."""

    def __init__(self, dim, num_heads=4, qk_norm=True):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q_proj = nn.Linear(dim, dim, bias=False)
        # MQA: K and V project to a single head, shared by all query heads.
        self.k_proj = nn.Linear(dim, self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.head_dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.q_norm = RMSNorm(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = RMSNorm(self.head_dim) if qk_norm else nn.Identity()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, context):
        B, N, D = x.shape
        S = context.shape[1]
        residual = x
        x = self.norm(x)
        q = self.q_proj(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(context).view(B, S, 1, self.head_dim).transpose(1, 2)
        v = self.v_proj(context).view(B, S, 1, self.head_dim).transpose(1, 2)
        q, k = self.q_norm(q), self.k_norm(k)
        # Broadcast the single K/V head across all query heads (views, no copy).
        k = k.expand(-1, self.num_heads, -1, -1)
        v = v.expand(-1, self.num_heads, -1, -1)
        attn = F.scaled_dot_product_attention(q, k, v, scale=1.0/math.sqrt(self.head_dim))
        return residual + self.out_proj(attn.transpose(1, 2).reshape(B, N, D))


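# Parameter cost of MQA vs. standard MHA for the K/V projections, as a
# back-of-the-envelope check: MHA needs 2 * dim * dim weights, MQA needs
# 2 * dim * head_dim = 2 * dim * dim / num_heads, a num_heads-fold reduction.
# E.g. with dim=256 and num_heads=4: 131072 -> 32768 K/V weights.
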
class MultiQuerySelfAttention(nn.Module):
    """Pre-norm self-attention with multi-query K/V and 2D RoPE."""

    def __init__(self, dim, num_heads=4, qk_norm=True):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.head_dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.q_norm = RMSNorm(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = RMSNorm(self.head_dim) if qk_norm else nn.Identity()
        self.norm = nn.LayerNorm(dim)
        self.rope = RotaryEmbedding2D(self.head_dim)

    def forward(self, x, H, W):
        B, N, D = x.shape
        residual = x
        x = self.norm(x)
        q = self.q_proj(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, N, 1, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, N, 1, self.head_dim).transpose(1, 2)
        q, k = self.q_norm(q), self.k_norm(k)
        # Apply 2D RoPE after QK-norm, before broadcasting K across heads.
        q, k = self.rope(q, H, W), self.rope(k, H, W)
        # (Newer PyTorch also offers enable_gqa=True in SDPA as an alternative
        # to the explicit expand.)
        k = k.expand(-1, self.num_heads, -1, -1)
        v = v.expand(-1, self.num_heads, -1, -1)
        attn = F.scaled_dot_product_attention(q, k, v, scale=1.0/math.sqrt(self.head_dim))
        return residual + self.out_proj(attn.transpose(1, 2).reshape(B, N, D))


class UIBFFN(nn.Module):
    """Universal Inverted Bottleneck FFN: pointwise expand, 3x3 depthwise conv,
    SiLU gate computed from the pre-conv input, pointwise project back.
    `spatial_size` is currently unused."""

    def __init__(self, dim, expansion=2, spatial_size=4):
        super().__init__()
        hidden = dim * expansion
        self.norm = nn.LayerNorm(dim)
        self.pw_up = nn.Linear(dim, hidden, bias=False)
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.dw_conv = nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden, bias=True)
        self.pw_down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x, H, W):
        B, N, D = x.shape
        residual = x
        x = self.norm(x)
        h = self.pw_up(x)
        g = F.silu(self.gate(x))
        # Depthwise conv needs channels-first 2D layout: (B, N, C) -> (B, C, H, W).
        h_2d = h.view(B, H, W, -1).permute(0, 3, 1, 2)
        h = self.dw_conv(h_2d).permute(0, 2, 3, 1).reshape(B, N, -1)
        # GLU-style gating: the gate branch bypasses the conv.
        return residual + self.pw_down(h * g)


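# Usage sketch for UIBFFN (illustrative sizes):
#   ffn = UIBFFN(dim=256, expansion=2)
#   x = torch.randn(2, 64, 256)          # (B, N, D) with N = 8 * 8
#   y = ffn(x, H=8, W=8)                 # (2, 64, 256)
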
class TimestepEmbedding(nn.Module):
    """Sinusoidal timestep features followed by a 2-layer SiLU MLP."""

    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.mlp = nn.Sequential(nn.Linear(dim, dim*4), nn.SiLU(), nn.Linear(dim*4, dim))

    def forward(self, t):
        half = self.dim // 2
        freqs = torch.exp(-math.log(self.max_period) * torch.arange(half, device=t.device, dtype=torch.float32) / half)
        args = t[:, None].float() * freqs[None, :]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2:
            emb = F.pad(emb, (0, 1))
        # Cast to the MLP's parameter dtype rather than t.dtype, so integer
        # timestep tensors don't truncate the sinusoidal features.
        return self.mlp(emb.to(self.mlp[0].weight.dtype))


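# Usage sketch for TimestepEmbedding (works for float or integer timestep tensors):
#   t_emb = TimestepEmbedding(dim=256)
#   e = t_emb(torch.rand(2))             # (2, 256)
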
class IterationEmbedding(nn.Module):
    """Learned embedding for an iteration index, shared across the batch."""

    def __init__(self, dim, max_iterations=8):
        super().__init__()
        self.embed = nn.Embedding(max_iterations, dim)

    def forward(self, iter_idx, batch_size, device):
        # Same iteration index for every sample in the batch -> (batch_size, dim).
        return self.embed(torch.full((batch_size,), iter_idx, device=device, dtype=torch.long))
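

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch of how the blocks compose into one pre-norm
# transformer block with additive timestep/iteration conditioning, as the
# module docstring describes. Sizes are illustrative; this is a shape check,
# not the IRIS model itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, H, W, D = 2, 8, 8, 64
    N = H * W
    x = torch.randn(B, N, D)
    ctx = torch.randn(B, 10, D)

    sa = MultiQuerySelfAttention(D, num_heads=4)
    ca = MultiQueryCrossAttention(D, num_heads=4)
    ffn = UIBFFN(D, expansion=2)
    t_emb = TimestepEmbedding(D)
    it_emb = IterationEmbedding(D, max_iterations=8)

    # Timestep addition (not AdaLN): conditioning is simply added to the tokens.
    cond = t_emb(torch.rand(B)) + it_emb(3, B, x.device)
    h = x + cond[:, None, :]
    h = sa(h, H, W)
    h = ca(h, ctx)
    h = ffn(h, H, W)
    assert h.shape == (B, N, D)
    print("ok:", tuple(h.shape))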