asdf98 committed on
Commit
e80ddea
·
verified ·
1 Parent(s): 8801354

Add microforge/planner.py

Browse files
Files changed (1) hide show
  1. microforge/planner.py +270 -0
microforge/planner.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Recurrent Latent Planner (RLP)
3
+ ===============================
4
+
5
+ The "reasoning core" of MicroForge. Inspired by:
6
+ - RIN (Recurrent Interface Networks, Jabri et al. 2022): decoupled latent tokens
7
+ that iteratively refine via cross-attention to image tokens
8
+ - DiMSUM shared attention: lightweight global context
9
+ - HRM/TRM recursive reasoning: iterative refinement of a compact state
10
+
11
+ The RLP maintains a fixed set of K latent tokens (the "plan") that:
12
+ 1. READ from the noised image latent to understand current state
13
+ 2. REASON internally via self-attention over plan tokens
14
+ 3. WRITE back to the image latent to guide denoising
15
+
16
+ This is applied BEFORE each denoising step, creating a planning loop:
17
+ plan_0 = init(text_emb)
18
+ for step in diffusion_steps:
19
+ plan_{s+1} = RLP.read_reason_write(z_s, plan_s, text_emb)
20
+ z_{s+1} = backbone(z_s, t_s, text_emb, plan_{s+1})
21
+
22
+ Key insight (from RIN): the plan tokens are much fewer than image tokens
23
+ (K=32 vs N=256+), so self-attention over plan is cheap. Cross-attention
24
+ (K queries, N keys) is O(K*N) which is small when K << N.
25
+
26
+ This gives the model a "thinking" mechanism: it can reason about the
27
+ image at a higher level before committing to pixel-level changes.
28
+
29
+ For editing: the planner can compare source and target latents and
30
+ plan what needs to change (like a diff operation in latent space).
31
+ """
32
+
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.nn.functional as F
36
+ from typing import Optional, Tuple
37
+
38
+
39
class PlannerReadWrite(nn.Module):
    """
    Cross-attention interface between plan tokens and image tokens.

    READ:  plan attends to image -> updates plan
    WRITE: image attends to plan -> plan guides image

    Both directions are pre-norm multi-head cross-attention followed by a
    bias-free output projection and a residual connection. Each direction
    owns its own projections and LayerNorms.

    Args:
        dim: channel width shared by plan and image tokens.
        num_heads: number of attention heads; must divide ``dim`` evenly.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim: int, num_heads: int = 4):
        # Fail early: with a non-divisible dim the per-head reshape in
        # read()/write() cannot succeed, and the resulting error is opaque.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        super().__init__()
        self.head_dim = dim // num_heads
        self.num_heads = num_heads

        # Read: plan tokens query, image tokens are keys/values
        self.read_q = nn.Linear(dim, dim, bias=False)
        self.read_kv = nn.Linear(dim, dim * 2, bias=False)
        self.read_out = nn.Linear(dim, dim, bias=False)
        self.read_norm_plan = nn.LayerNorm(dim)
        self.read_norm_img = nn.LayerNorm(dim)

        # Write: image tokens query, plan tokens are keys/values
        self.write_q = nn.Linear(dim, dim, bias=False)
        self.write_kv = nn.Linear(dim, dim * 2, bias=False)
        self.write_out = nn.Linear(dim, dim, bias=False)
        self.write_norm_img = nn.LayerNorm(dim)
        self.write_norm_plan = nn.LayerNorm(dim)

    def _attention(self, q, k, v):
        """Softmax attention; q: [B,H,Lq,hd], k/v: [B,H,Lkv,hd] -> [B,H,Lq,hd]."""
        scale = q.shape[-1] ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        return attn @ v

    def read(self, plan: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
        """Plan reads from image. plan: [B,K,D], image: [B,N,D] -> updated plan [B,K,D]"""
        B, K, D = plan.shape
        N = image.shape[1]

        # Queries from normalized plan tokens: [B, H, K, head_dim]
        q = self.read_q(self.read_norm_plan(plan))
        q = q.reshape(B, K, self.num_heads, self.head_dim).transpose(1, 2)

        # Fused key/value projection of normalized image tokens; axis 2
        # separates keys (index 0) from values (index 1).
        kv = self.read_kv(self.read_norm_img(image))
        kv = kv.reshape(B, N, 2, self.num_heads, self.head_dim)
        k, v = kv[:, :, 0].transpose(1, 2), kv[:, :, 1].transpose(1, 2)

        out = self._attention(q, k, v)
        out = out.transpose(1, 2).reshape(B, K, D)  # merge heads
        return plan + self.read_out(out)

    def write(self, image: torch.Tensor, plan: torch.Tensor) -> torch.Tensor:
        """Plan writes to image. image: [B,N,D], plan: [B,K,D] -> updated image [B,N,D]"""
        B, N, D = image.shape
        K = plan.shape[1]

        # Queries from normalized image tokens: [B, H, N, head_dim]
        q = self.write_q(self.write_norm_img(image))
        q = q.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        # Fused key/value projection of normalized plan tokens.
        kv = self.write_kv(self.write_norm_plan(plan))
        kv = kv.reshape(B, K, 2, self.num_heads, self.head_dim)
        k, v = kv[:, :, 0].transpose(1, 2), kv[:, :, 1].transpose(1, 2)

        out = self._attention(q, k, v)
        out = out.transpose(1, 2).reshape(B, N, D)  # merge heads
        return image + self.write_out(out)
96
+
97
+
98
class PlannerReasoning(nn.Module):
    """
    Self-attention + FFN over plan tokens.

    This is where the "thinking" happens - plan tokens reason about
    what the image should look like. The block is pre-norm with two
    residual branches: multi-head self-attention, then a feed-forward
    network whose normalized input is modulated by a scale/shift pair
    predicted from the conditioning vector.

    Args:
        dim: plan token width.
        num_heads: number of self-attention heads.
        ffn_expansion: hidden-width multiplier of the feed-forward network.
    """

    def __init__(self, dim: int, num_heads: int = 4, ffn_expansion: int = 3):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * ffn_expansion),
            nn.GELU(),
            nn.Linear(dim * ffn_expansion, dim),
        )
        # Condition integration
        self.cond_proj = nn.Linear(dim, dim * 2)  # scale and shift

    def forward(self, plan: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Refine the plan tokens once.

        plan: [B, K, D]
        cond: [B, D] (timestep + text condition)

        Returns the updated plan, [B, K, D].
        """
        # Self-attention over plan tokens. The attention map is never used,
        # so skip computing it (need_weights=False) to avoid wasted work.
        h = self.norm1(plan)
        h, _ = self.attn(h, h, h, need_weights=False)
        plan = plan + h

        # Conditioned FFN: one scale/shift pair per sample, broadcast over
        # all K plan tokens.
        params = self.cond_proj(cond).unsqueeze(1)  # [B, 1, 2D]
        scale, shift = params.chunk(2, dim=-1)
        h = self.norm2(plan)
        h = h * (1 + scale) + shift
        plan = plan + self.ffn(h)

        return plan
135
+
136
+
137
class RecurrentLatentPlanner(nn.Module):
    """
    Recurrent Latent Planner (RLP).

    Maintains K latent plan tokens that iteratively refine across
    denoising steps. Each refinement involves:
    1. READ: plan attends to current noised image
    2. REASON: plan tokens self-attend and process with FFN
    3. WRITE: plan injects guidance back into image tokens
       (currently disabled in ``forward``; see the note there)

    The plan carries forward across denoising steps via latent self-conditioning
    (from RIN). At step s, the plan from step s-1 is used as initialization,
    creating a persistent "memory" of the generation process.

    Parameters:
    - num_plan_tokens: K, number of plan tokens (default 32)
    - dim: token dimension D of the planner (default 384)
    - text_dim: dimension of the pooled text embedding used for initialization;
      also the width of the tokens returned for backbone injection
    - latent_channels: channel width of the incoming image tokens
    - num_layers: depth of reasoning (default 2)
    - num_heads: attention heads used by every read/reason layer

    Memory: K * D * 4 bytes per plan = 32 * 384 * 4 = 49KB in float32 (negligible)
    Compute: O(K^2 + K*N) attention scores per layer
             (K=32, N=256 -> 1024 + 8192 ~= 9K, vs N^2 ~= 65K for full attention)
    """

    def __init__(
        self,
        num_plan_tokens: int = 32,
        dim: int = 384,
        text_dim: int = 768,
        latent_channels: int = 32,
        num_layers: int = 2,
        num_heads: int = 4,
    ):
        super().__init__()
        self.num_plan_tokens = num_plan_tokens
        self.dim = dim

        # Input projection: map raw latent channels to planner dim
        self.image_proj = nn.Linear(latent_channels, dim)

        # Learnable initial plan tokens
        self.init_tokens = nn.Parameter(torch.randn(1, num_plan_tokens, dim) * 0.02)

        # Text-to-plan projection (initialize plan from text)
        self.text_to_plan = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

        # Timestep projection.
        # NOTE(review): not called by any method of this class; presumably the
        # caller applies it before passing t_emb to forward() - TODO confirm.
        self.time_proj = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

        # Stacked read-reason-write layers
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(nn.ModuleDict({
                'read_write': PlannerReadWrite(dim, num_heads),
                'reason': PlannerReasoning(dim, num_heads),
            }))

        # Final projection to backbone-compatible tokens (must match text_dim)
        self.output_proj = nn.Linear(dim, text_dim)
        self.output_norm = nn.LayerNorm(dim)

        # Self-conditioning weight (learnable, from RIN). Stored as a logit:
        # it is passed through a sigmoid before use, so the initial mixing
        # weight is sigmoid(0.5) ~= 0.62, not 0.5.
        self.self_cond_weight = nn.Parameter(torch.tensor(0.5))

    def initialize_plan(
        self,
        text_pooled: torch.Tensor,
        batch_size: int,
        prev_plan: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Initialize plan tokens from text and (optionally) previous plan.

        text_pooled: [B, text_dim]
        batch_size: B, number of plans to create
        prev_plan: [B, K, D] from previous denoising step (latent self-conditioning)

        Returns the initial plan, [B, K, D].
        """
        # Learnable base + text-guided initialization (the same text offset
        # is added to every one of the K tokens)
        plan = self.init_tokens.expand(batch_size, -1, -1)
        text_cond = self.text_to_plan(text_pooled).unsqueeze(1)  # [B, 1, D]
        plan = plan + text_cond

        # Latent self-conditioning from previous step: convex blend of the
        # previous plan and the fresh initialization
        if prev_plan is not None:
            w = torch.sigmoid(self.self_cond_weight)
            plan = w * prev_plan + (1 - w) * plan

        return plan

    def forward(
        self,
        image_tokens: torch.Tensor,
        plan: torch.Tensor,
        t_emb: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Read-reason cycle over all layers, then project for the backbone.

        Args:
            image_tokens: [B, N, latent_channels] - patchified noised image
                latent (projected to the planner dim D internally)
            plan: [B, K, D] - current plan tokens
            t_emb: [B, D] - timestep embedding

        Returns:
            updated_plan: [B, K, D] - refined plan
            planner_output: [B, K, text_dim] - tokens to inject into backbone
        """
        cond = t_emb  # Could add more conditioning here

        # Project image tokens to planner dimension
        image_tokens = self.image_proj(image_tokens)

        for layer in self.layers:
            # READ: plan learns from image
            plan = layer['read_write'].read(plan, image_tokens)
            # REASON: plan self-refines
            plan = layer['reason'](plan, cond)
            # WRITE: plan guides image (optional, only in advanced mode).
            # Disabled here, so the write_* parameters of each layer are not
            # exercised by this forward pass:
            # image_tokens = layer['read_write'].write(image_tokens, plan)

        # Project plan tokens for backbone injection
        output = self.output_proj(self.output_norm(plan))
        return plan, output

    def get_plan_size_bytes(self) -> int:
        """Return size of one sample's plan state in bytes (for memory budgeting).

        Uses the element size of the learnable plan tokens, so the figure
        stays correct when the module is cast to half/bfloat16 (4 bytes per
        element for the default float32).
        """
        return self.num_plan_tokens * self.dim * self.init_tokens.element_size()