"""
LuminaRS -- Lightweight Latent Recursive Diffusion for Art/Illustration.
~100M params, designed for 2-4GB VRAM mobile deployment.
Uses pretrained VAE + CLIP text encoder (both frozen).
"""
import math, torch, torch.nn as nn, torch.nn.functional as F

# ── helpers ──────────────────────────────────────────────────────────────

def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding of diffusion timesteps: t (B,) -> (B, dim)."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) *
        torch.arange(half, device=t.device, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], -1)
    if dim % 2:
        emb = F.pad(emb, (0, 1))
    return emb

class DropPath(nn.Module):
    """Stochastic depth: drops the whole residual branch per sample with prob p."""
    def __init__(self, p=0.0):
        super().__init__(); self.p = p
    def forward(self, x):
        if self.p == 0. or not self.training: return x
        mask = torch.zeros(x.shape[0],1,1,1, device=x.device).bernoulli_(1-self.p)
        return x * mask / (1 - self.p)

class AdaLN(nn.Module):
    """Adaptive LayerNorm: shift+scale from time embedding."""
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.GroupNorm(1, dim)
        self.s = nn.Linear(cond_dim, dim)
        self.b = nn.Linear(cond_dim, dim)
    def forward(self, x, c):
        """x: (B,C,H,W)  c: (B,cond_dim)"""
        h = self.norm(x)
        s = self.s(c)[:,:,None,None]
        b = self.b(c)[:,:,None,None]
        return h * (1+s) + b

class MQAttention(nn.Module):
    """Multi-Query Attention: a single K,V head shared across all query heads."""
    def __init__(self, dim, n_heads=4):
        super().__init__()
        assert dim % n_heads == 0
        self.nh = n_heads; self.dh = dim // n_heads; self.scale = self.dh**-0.5
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, self.dh)  # one shared key head
        self.v = nn.Linear(dim, self.dh)  # one shared value head
        self.o = nn.Linear(dim, dim)
    def forward(self, x, ctx=None):
        B, L, C = x.shape
        ctx = x if ctx is None else ctx
        q = self.q(x).view(B, L, self.nh, self.dh).transpose(1, 2)  # (B, nh, L, dh)
        k = self.k(ctx).unsqueeze(1)                                # (B, 1, Lc, dh), broadcast over heads
        v = self.v(ctx).unsqueeze(1)                                # (B, 1, Lc, dh)
        a = (q @ k.transpose(-2, -1)) * self.scale                  # (B, nh, L, Lc)
        return self.o((a.softmax(-1) @ v).transpose(1, 2).reshape(B, L, C))

# ── ConvNeXt + cross-attn block ─────────────────────────────────────────

class ResBlock(nn.Module):
    def __init__(self, dim, cond_dim, text_dim=None, drop_path=0.0):
        super().__init__()
        self.adaln = AdaLN(dim, cond_dim)
        self.dw = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        self.pw1 = nn.Linear(dim, dim*2)
        self.act = nn.GELU()
        self.pw2 = nn.Linear(dim*2, dim)
        self.drop = DropPath(drop_path)
        self.xattn = None
        if text_dim is not None:
            self.xattn = MQAttention(dim)
            self.tproj = nn.Linear(text_dim, dim)
            self.tnorm = nn.GroupNorm(1, dim)
    def forward(self, x, cond, text=None):
        r = x
        x = self.adaln(x, cond)
        x = self.dw(x)
        x = x.permute(0,2,3,1)
        x = self.pw2(self.act(self.pw1(x)))
        x = x.permute(0,3,1,2)
        x = r + self.drop(x)
        if self.xattn is not None and text is not None:
            B, C, H, W = x.shape
            xf = self.tnorm(x).view(B, C, H*W).transpose(1, 2)  # pre-norm, flatten to (B, HW, C)
            xf = self.xattn(xf, self.tproj(text))               # cross-attend to projected text tokens
            x = x + xf.transpose(1, 2).reshape(B, C, H, W)      # reshape (not view): non-contiguous after transpose
        return x

# ── Down / Up ────────────────────────────────────────────────────────────

class Down(nn.Module):
    def __init__(self, cin, cout, cond_dim, text_dim=None, n=2, dp=0.0):
        super().__init__()
        self.blocks = nn.ModuleList([
            ResBlock(cin if i==0 else cout, cond_dim, text_dim, dp) for i in range(n)
        ])
        if cin != cout:
            self.skip_conv = nn.Conv2d(cin, cout, 1)
        else:
            self.skip_conv = None
        self.down = nn.Conv2d(cout, cout, 3, stride=2, padding=1)
    def forward(self, x, cond, text=None):
        for i, b in enumerate(self.blocks):
            x = b(x, cond, text)
            if i == 0 and self.skip_conv is not None:
                x = self.skip_conv(x)  # project cin -> cout between the two blocks
        skip = x                       # skip carries cout channels at pre-downsample resolution
        return self.down(x), skip

class Up(nn.Module):
    def __init__(self, cin, cout, cond_dim, text_dim=None, n=2, dp=0.0):
        super().__init__()
        self.up = nn.ConvTranspose2d(cin, cin, 2, stride=2)
        self.blocks = nn.ModuleList([
            ResBlock(cin if i==0 else cout, cond_dim, text_dim, dp) for i in range(n)
        ])
        self.skip_conv = nn.Conv2d(cin, cout, 1) if cin != cout else nn.Identity()
    def forward(self, x, skip, cond, text=None):
        x = self.up(x)
        x = x + skip
        for i, b in enumerate(self.blocks):
            x = b(x, cond, text)
            if i == 0:
                x = self.skip_conv(x)  # project cin -> cout after the first block (Identity when cin == cout)
        return x

# ── Main Model ───────────────────────────────────────────────────────────

class LuminaRS(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        chs = cfg.channels  # e.g. (64, 128, 256, 384)
        t_dim = cfg.t_embed_dim
        cond_dim = chs[-1] * 4

        self.time_mlp = nn.Sequential(
            nn.Linear(t_dim, cond_dim), nn.SiLU(), nn.Linear(cond_dim, cond_dim))

        self.text_proj = nn.Linear(cfg.text_embed_dim, cfg.text_proj_dim)

        self.in_conv = nn.Conv2d(cfg.latent_dim, chs[0], 3, padding=1)

        # Encoder
        self.enc = nn.ModuleList()
        for i in range(len(chs)-1):
            self.enc.append(Down(chs[i], chs[i+1], cond_dim, cfg.text_proj_dim, n=2, dp=cfg.drop_path))

        # Bottleneck
        self.bot = nn.ModuleList([
            ResBlock(chs[-1], cond_dim, cfg.text_proj_dim, drop_path=cfg.drop_path)
            for _ in range(cfg.n_bottleneck)])

        # Decoder
        self.dec = nn.ModuleList()
        for i in range(len(chs)-1, 0, -1):
            self.dec.append(Up(chs[i], chs[i-1], cond_dim, cfg.text_proj_dim, n=2, dp=cfg.drop_path))

        self.out_conv = nn.Conv2d(chs[0], cfg.latent_dim, 1)
        self.n_recurse = cfg.n_recurse

    def forward(self, z, text_emb, t):
        B = z.shape[0]
        t_emb = self.time_mlp(timestep_embedding(t, self.cfg.t_embed_dim))
        text = self.text_proj(text_emb)
        x = self.in_conv(z)
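        # weight-tied recursion: the same encoder/bottleneck/decoder is applied
        # n_recurse times, each pass adding a residual refinement to x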
        for _ in range(self.n_recurse):
            h = x; skips = []
            for down in self.enc:
                h, sk = down(h, t_emb, text)
                skips.append(sk)
            for blk in self.bot:
                h = blk(h, t_emb, text)
            for up in self.dec:
                sk = skips.pop()
                h = up(h, sk, t_emb, text)
            x = x + h
        return self.out_conv(x)
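
# ── usage sketch ─────────────────────────────────────────────────────────

# Minimal smoke test, not part of the shipped model. The config fields match
# what the constructor reads, but the concrete values (channel widths, CLIP
# width 512, 4-channel latents, 77 text tokens) are illustrative assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        channels=(64, 128, 256, 384), t_embed_dim=256, text_embed_dim=512,
        text_proj_dim=256, latent_dim=4, drop_path=0.1, n_bottleneck=2,
        n_recurse=2)
    model = LuminaRS(cfg).eval()
    z = torch.randn(2, cfg.latent_dim, 32, 32)     # VAE latents (e.g. 256px images / 8)
    text = torch.randn(2, 77, cfg.text_embed_dim)  # frozen CLIP token embeddings
    t = torch.randint(0, 1000, (2,))               # diffusion timesteps
    with torch.no_grad():
        out = model(z, text, t)
    print(out.shape)                               # torch.Size([2, 4, 32, 32])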