asdf98 committed
Commit c98929a · verified · 1 Parent(s): fe73fcc

Fix: register context_proj as a proper nn.Module (was created lazily, not saved in checkpoints)

Files changed (1)
  1. iris/model.py +31 -6
iris/model.py CHANGED
@@ -65,8 +65,16 @@ class IRIS(nn.Module):
     """
     IRIS: Iterative Refinement Image Synthesizer.
     Predicts velocity v_theta(z_t, t, c) for flow matching.
+
+    Args:
+        text_dim: dimension of text encoder output. If different from dim,
+            a learned linear projection is applied. Set to 384 for
+            all-MiniLM-L6-v2, 512 for CLIP, etc. Set to None or
+            equal to dim to skip projection.
     """
-    def __init__(self, latent_channels=32, dim=512, patch_size=4, num_blocks=6, num_heads=8, max_iterations=8, ffn_expansion=2, gradient_checkpointing=True):
+    def __init__(self, latent_channels=32, dim=512, patch_size=4, num_blocks=6,
+                 num_heads=8, max_iterations=8, ffn_expansion=2,
+                 gradient_checkpointing=True, text_dim=None):
         super().__init__()
         self.latent_channels = latent_channels
         self.dim = dim
@@ -75,8 +83,17 @@ class IRIS(nn.Module):
         self.patchify = Patchify(latent_channels, dim, patch_size)
         self.unpatchify = Unpatchify(latent_channels, dim, patch_size)
         spatial_size = 4 # default for 16x16 latent with ps=4
-        self.core = RefinementCore(dim=dim, num_blocks=num_blocks, num_heads=num_heads, spatial_size=spatial_size, max_iterations=max_iterations, ffn_expansion=ffn_expansion, gradient_checkpointing=gradient_checkpointing)
+        self.core = RefinementCore(dim=dim, num_blocks=num_blocks, num_heads=num_heads,
+                                   spatial_size=spatial_size, max_iterations=max_iterations,
+                                   ffn_expansion=ffn_expansion, gradient_checkpointing=gradient_checkpointing)
         self.tiny_decoder = TinyDecoder(latent_channels, out_channels=3)
+
+        # Text projection: maps text encoder dim to model dim if they differ
+        if text_dim is not None and text_dim != dim:
+            self.context_proj = nn.Linear(text_dim, dim, bias=False)
+        else:
+            self.context_proj = None
+
         self._init_weights()
 
     def _init_weights(self):
@@ -95,10 +112,18 @@ class IRIS(nn.Module):
 
     def forward(self, z_t, t, context, num_iterations=4):
         tokens, H_tok, W_tok = self.patchify(z_t)
-        if context.shape[-1] != self.dim:
-            if not hasattr(self, '_context_proj'):
-                self._context_proj = nn.Linear(context.shape[-1], self.dim, bias=False).to(context.device, context.dtype)
-            context = self._context_proj(context)
+
+        # Project text embeddings to model dim if needed
+        if self.context_proj is not None:
+            context = self.context_proj(context)
+        elif context.shape[-1] != self.dim:
+            # Fallback: lazy projection for backwards compat
+            if not hasattr(self, '_lazy_context_proj'):
+                self._lazy_context_proj = nn.Linear(
+                    context.shape[-1], self.dim, bias=False
+                ).to(context.device, context.dtype)
+            context = self._lazy_context_proj(context)
+
         refined = self.core(tokens, context, t, H_tok, W_tok, num_iterations=num_iterations)
         return self.unpatchify(refined, H_tok, W_tok)
 