Fix: register context_proj as proper nn.Module (was lazy, not saved in checkpoints)
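Why this matters, as a minimal sketch (illustrative class and dimensions, not the actual IRIS code): a submodule created lazily inside forward() does not exist on a freshly constructed model, so a strict load_state_dict() on a checkpoint from a trained model fails on the unexpected key (or, with strict=False, silently drops the trained projection weights), and an optimizer built right after construction never sees the lazily created parameters.

import torch
import torch.nn as nn

class LazyProjExample(nn.Module):
    # Illustrative stand-in for the old lazy behaviour, not the IRIS code itself.
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(512, 512)

    def forward(self, context):
        # Projection only comes into existence on the first forward pass.
        if not hasattr(self, "context_proj"):
            self.context_proj = nn.Linear(context.shape[-1], 512, bias=False)
        return self.body(self.context_proj(context))

trained = LazyProjExample()
trained(torch.randn(2, 384))      # lazy projection now exists
ckpt = trained.state_dict()       # includes "context_proj.weight"

fresh = LazyProjExample()         # no forward yet -> no context_proj submodule
fresh.load_state_dict(ckpt)       # RuntimeError: unexpected key(s) in state_dict

Registering the projection in __init__, as this commit does, makes the key set deterministic, so checkpoints round-trip cleanly and model.parameters() already contains the projection when the optimizer is built.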
iris/model.py  +31 -6
@@ -65,8 +65,16 @@ class IRIS(nn.Module):
     """
     IRIS: Iterative Refinement Image Synthesizer.
     Predicts velocity v_theta(z_t, t, c) for flow matching.
+
+    Args:
+        text_dim: dimension of text encoder output. If different from dim,
+            a learned linear projection is applied. Set to 384 for
+            all-MiniLM-L6-v2, 512 for CLIP, etc. Set to None or
+            equal to dim to skip projection.
     """
-    def __init__(self, latent_channels=32, dim=512, patch_size=4, num_blocks=6,
+    def __init__(self, latent_channels=32, dim=512, patch_size=4, num_blocks=6,
+                 num_heads=8, max_iterations=8, ffn_expansion=2,
+                 gradient_checkpointing=True, text_dim=None):
         super().__init__()
         self.latent_channels = latent_channels
         self.dim = dim
@@ -75,8 +83,17 @@ class IRIS(nn.Module):
         self.patchify = Patchify(latent_channels, dim, patch_size)
         self.unpatchify = Unpatchify(latent_channels, dim, patch_size)
         spatial_size = 4  # default for 16x16 latent with ps=4
-        self.core = RefinementCore(dim=dim, num_blocks=num_blocks, num_heads=num_heads,
+        self.core = RefinementCore(dim=dim, num_blocks=num_blocks, num_heads=num_heads,
+                                   spatial_size=spatial_size, max_iterations=max_iterations,
+                                   ffn_expansion=ffn_expansion, gradient_checkpointing=gradient_checkpointing)
         self.tiny_decoder = TinyDecoder(latent_channels, out_channels=3)
+
+        # Text projection: maps text encoder dim to model dim if they differ
+        if text_dim is not None and text_dim != dim:
+            self.context_proj = nn.Linear(text_dim, dim, bias=False)
+        else:
+            self.context_proj = None
+
         self._init_weights()

     def _init_weights(self):
@@ -95,10 +112,18 @@ class IRIS(nn.Module):

     def forward(self, z_t, t, context, num_iterations=4):
         tokens, H_tok, W_tok = self.patchify(z_t)
-
-
-
-        context = self.
+
+        # Project text embeddings to model dim if needed
+        if self.context_proj is not None:
+            context = self.context_proj(context)
+        elif context.shape[-1] != self.dim:
+            # Fallback: lazy projection for backwards compat
+            if not hasattr(self, '_lazy_context_proj'):
+                self._lazy_context_proj = nn.Linear(
+                    context.shape[-1], self.dim, bias=False
+                ).to(context.device, context.dtype)
+            context = self._lazy_context_proj(context)
+
         refined = self.core(tokens, context, t, H_tok, W_tok, num_iterations=num_iterations)
         return self.unpatchify(refined, H_tok, W_tok)

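Usage after this change, as a hedged sketch based only on the signature shown in the diff (the import path and checkpoint file name are assumptions):

import torch
from iris.model import IRIS

model = IRIS(dim=512, text_dim=384)                  # e.g. all-MiniLM-L6-v2 text embeddings
print("context_proj.weight" in model.state_dict())   # True even before any forward pass

torch.save(model.state_dict(), "iris_demo.pt")
fresh = IRIS(dim=512, text_dim=384)
fresh.load_state_dict(torch.load("iris_demo.pt"))    # strict load succeeds

Passing text_dim=None (or a value equal to dim) keeps context_proj as None, and the old lazy path in forward() remains only as a backwards-compatibility fallback for contexts whose width differs from dim.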