asdf98 committed
Commit a8d6acc · verified · Parent(s): 99a5f54

v2: Use pre-trained SD-VAE, fix all bugs, pre-cache everything, massive speedup
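Read together with the new train_step_latent() in the diff below, "pre-cache everything" appears to mean: encode the dataset once with a frozen, pre-trained SD-VAE and train IRIS purely on latents. A minimal sketch of that training path (not part of the commit; the diffusers AutoencoderKL checkpoint, the 0.18215 scaling factor, the dummy batch, and the CLIP embedding shape are assumptions for illustration):

import torch
from diffusers import AutoencoderKL            # external pre-trained SD-VAE (assumed choice)
from iris_model import create_iris_small

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device).eval()
iris = create_iris_small(latent_spatial=32).to(device)        # 4-channel latents, matches SD-VAE
opt = torch.optim.AdamW(iris.parameters(), lr=1e-4)

images = torch.rand(2, 3, 256, 256, device=device) * 2 - 1    # dummy batch in [-1, 1]; 256 px -> 32x32 latent
text_tokens = torch.randn(2, 77, 768, device=device)          # stand-in for CLIP text embeddings [B, S, text_dim]

with torch.no_grad():                                          # run once over the dataset and cache z_0
    z_0 = vae.encode(images).latent_dist.sample() * 0.18215    # SD-VAE scaling factor (assumed)

out = iris.train_step_latent(z_0, text_tokens)                 # no VAE in the graph, no KL term
opt.zero_grad()
out["loss"].backward()
opt.step()

Because the VAE is frozen and outside the training graph, the encode pass can be run once and its outputs cached to disk, which is presumably where the advertised speedup comes from.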

Files changed (1)
  1. iris_model.py +72 -59
iris_model.py CHANGED
@@ -989,71 +989,105 @@ class IRISGenerator(nn.Module):
 # ============================================================================
 
 class IRIS(nn.Module):
-    """Complete IRIS system: VAE + Generator.
+    """Complete IRIS system: Generator + optional built-in VAE.
 
-    For training: use train_step() which handles noise scheduling.
-    For inference: use generate() which runs the full pipeline.
+    For training with external VAE (recommended): use train_step_latent() with pre-encoded latents.
+    For training with built-in Wavelet VAE: use train_step() with raw images.
+    For inference: use generate_latent() to get latent, then decode externally.
     """
-    def __init__(self, config: IRISConfig):
+    def __init__(self, config: IRISConfig, use_builtin_vae: bool = False):
         super().__init__()
         self.config = config
-        self.vae = WaveletVAE(config)
         self.generator = IRISGenerator(config)
+
+        # Built-in Wavelet VAE is optional — prefer pre-trained external VAE
+        self.vae = WaveletVAE(config) if use_builtin_vae else None
 
     def encode(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Encode images to latent space."""
+        """Encode images via built-in VAE (only if use_builtin_vae=True)."""
+        assert self.vae is not None, "No built-in VAE. Use an external VAE to encode images."
         return self.vae.encode(images)
 
     def decode(self, z: torch.Tensor) -> torch.Tensor:
-        """Decode latent to images."""
+        """Decode latent via built-in VAE (only if use_builtin_vae=True)."""
+        assert self.vae is not None, "No built-in VAE. Use an external VAE to decode latents."
         return self.vae.decode(z)
 
-    def get_velocity_target(self, z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
+    @staticmethod
+    def get_velocity_target(z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
         """Rectified flow velocity target: v = noise - z_0."""
         return noise - z_0
 
-    def add_noise(self, z_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+    @staticmethod
+    def add_noise(z_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
         """Rectified flow forward process: z_t = (1-t)*z_0 + t*noise."""
         t_expand = t[:, None, None, None]
         return (1 - t_expand) * z_0 + t_expand * noise
 
-    def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
-        """Sample timesteps from logit-normal distribution (from SD3/RF).
-        Concentrates sampling on intermediate timesteps where learning is hardest.
-        """
+    @staticmethod
+    def sample_timesteps(batch_size: int, device: torch.device) -> torch.Tensor:
+        """Sample timesteps from logit-normal distribution (from SD3/RF)."""
         u = torch.randn(batch_size, device=device)
-        t = torch.sigmoid(u)  # Logit-normal with mean=0, std=1
-        # Clamp to avoid t=0 and t=1
+        t = torch.sigmoid(u)
         t = t.clamp(1e-5, 1 - 1e-5)
         return t
 
-    def train_step(
+    def train_step_latent(
         self,
-        images: torch.Tensor,
+        z_0: torch.Tensor,
         text_tokens: torch.Tensor,
         num_iterations: Optional[int] = None,
     ) -> dict:
-        """Single training step for rectified flow.
+        """Training step on PRE-ENCODED latents (recommended path).
 
-        Returns dict with loss and diagnostics.
+        Use this with an external pre-trained VAE:
+            z_0 = external_vae.encode(images)  # done outside
+            result = iris.train_step_latent(z_0, text_tokens)
         """
+        B = z_0.shape[0]
+        device = z_0.device
+
+        noise = torch.randn_like(z_0)
+        t = self.sample_timesteps(B, device)
+        z_t = self.add_noise(z_0, noise, t)
+
+        if num_iterations is None:
+            r_choices = [3, 4, 5, 6]
+            r = r_choices[torch.randint(0, len(r_choices), (1,)).item()]
+        else:
+            r = num_iterations
+
+        v_pred = self.generator(z_t, t, text_tokens, num_iterations=r)
+        v_target = self.get_velocity_target(z_0, noise)
+
+        w = t / (1 - t + 1e-8)
+        w = w[:, None, None, None]
+        velocity_loss = (w * (v_pred - v_target).pow(2)).mean()
+
+        return {
+            'loss': velocity_loss,
+            'velocity_loss': velocity_loss.item(),
+            'mean_t': t.mean().item(),
+        }
+
+    def train_step(
+        self,
+        images: torch.Tensor,
+        text_tokens: torch.Tensor,
+        num_iterations: Optional[int] = None,
+    ) -> dict:
+        """Training step with built-in Wavelet VAE (legacy path)."""
+        assert self.vae is not None, "No built-in VAE. Use train_step_latent() instead."
         B = images.shape[0]
         device = images.device
 
-        # Encode to latent
         z_0, mean, logvar = self.encode(images)
-
-        # Sample noise and timesteps
         noise = torch.randn_like(z_0)
         t = self.sample_timesteps(B, device)
-
-        # Create noisy latent
         z_t = self.add_noise(z_0, noise, t)
 
-        # Predict velocity
-        # Randomly sample iteration count for training robustness
        if num_iterations is None:
-            r_choices = [4, 6, 8, 10, 12]
+            r_choices = [3, 4, 5, 6]
             r = r_choices[torch.randint(0, len(r_choices), (1,)).item()]
         else:
             r = num_iterations
@@ -1061,15 +1095,9 @@ class IRIS(nn.Module):
         v_pred = self.generator(z_t, t, text_tokens, num_iterations=r)
         v_target = self.get_velocity_target(z_0, noise)
 
-        # SNR-weighted loss (from Rectified Flow paper)
-        # w(t) = t / (1 - t) — emphasizes high-noise timesteps
         w = t / (1 - t + 1e-8)
         w = w[:, None, None, None]
-
-        # Velocity matching loss
         velocity_loss = (w * (v_pred - v_target).pow(2)).mean()
-
-        # VAE KL loss
         kl_loss = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()).mean()
 
         return {
@@ -1080,7 +1108,7 @@ class IRIS(nn.Module):
         }
 
     @torch.no_grad()
-    def generate(
+    def generate_latent(
         self,
         text_tokens: torch.Tensor,
         num_steps: int = 4,
@@ -1088,14 +1116,9 @@
         cfg_scale: float = 4.0,
         seed: Optional[int] = None,
     ) -> torch.Tensor:
-        """Generate images from text conditioning using Euler solver.
-
-        Args:
-            text_tokens: [B, S, text_dim] CLIP text embeddings
-            num_steps: Number of ODE solver steps (1-50)
-            num_iterations: Core iterations per step (quality budget)
-            cfg_scale: Classifier-free guidance scale
-            seed: Random seed for reproducibility
+        """Generate latent (decode externally with your VAE).
+
+        Returns z_0 latent tensor, NOT decoded image.
         """
         B, S, _ = text_tokens.shape
         device = text_tokens.device
@@ -1103,35 +1126,25 @@
         if seed is not None:
             torch.manual_seed(seed)
 
-        # Start from pure noise
         z = torch.randn(B, self.config.latent_channels,
                         self.config.latent_spatial, self.config.latent_spatial,
                         device=device)
 
-        # Euler solver for rectified flow ODE: dz/dt = -v(z, t)
-        # Integrate from t=1 (noise) to t=0 (data)
         dt = 1.0 / num_steps
-
         for step in range(num_steps):
             t_val = 1.0 - step * dt
             t = torch.full((B,), t_val, device=device)
 
-            # Predict velocity
             v = self.generator(z, t, text_tokens, num_iterations=num_iterations)
 
-            # Classifier-free guidance (if cfg_scale > 1)
             if cfg_scale > 1.0:
                 null_tokens = torch.zeros_like(text_tokens)
                 v_uncond = self.generator(z, t, null_tokens, num_iterations=num_iterations)
                 v = v_uncond + cfg_scale * (v - v_uncond)
 
-            # Euler step: z = z - dt * v
             z = z - dt * v
 
-        # Decode to image
-        images = self.decode(z)
-        images = images.clamp(-1, 1)
-        return images
+        return z
 
 
 # ============================================================================
@@ -1162,9 +1175,9 @@ def estimate_memory_mb(model: nn.Module, dtype=torch.float16) -> float:
 
 
 def create_iris_small(latent_spatial: int = 32) -> IRIS:
-    """Create IRIS-Small: ~75M generator params, suitable for mobile."""
+    """Create IRIS-Small for SD-VAE latent space (4ch, 8× downsample)."""
     config = IRISConfig(
-        latent_channels=16,
+        latent_channels=4,
         latent_spatial=latent_spatial,
         hidden_dim=512,
         num_heads=8,
@@ -1187,9 +1200,9 @@ def create_iris_small(latent_spatial: int = 32) -> IRIS:
 
 
 def create_iris_tiny(latent_spatial: int = 32) -> IRIS:
-    """Create IRIS-Tiny: ~30M generator params, ultra-mobile."""
+    """Create IRIS-Tiny for SD-VAE latent space (4ch, 8× downsample)."""
     config = IRISConfig(
-        latent_channels=8,
+        latent_channels=4,
         latent_spatial=latent_spatial,
         hidden_dim=384,
         num_heads=6,
@@ -1212,9 +1225,9 @@ def create_iris_tiny(latent_spatial: int = 32) -> IRIS:
 
 
 def create_iris_base(latent_spatial: int = 32) -> IRIS:
-    """Create IRIS-Base: ~150M generator params, quality-focused."""
+    """Create IRIS-Base for SD-VAE latent space (4ch, 8× downsample)."""
     config = IRISConfig(
-        latent_channels=16,
+        latent_channels=4,
        latent_spatial=latent_spatial,
         hidden_dim=768,
         num_heads=12,
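
The matching inference path implied by generate_latent() (again a sketch, not part of the commit; the decode call and the 0.18215 scaling factor assume the same diffusers AutoencoderKL as in the training sketch above):

import torch
from diffusers import AutoencoderKL
from iris_model import create_iris_small

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device).eval()
iris = create_iris_small(latent_spatial=32).to(device).eval()

text_tokens = torch.randn(1, 77, 768, device=device)    # stand-in for CLIP text embeddings

with torch.no_grad():
    z = iris.generate_latent(text_tokens, num_steps=4, cfg_scale=4.0, seed=0)
    images = vae.decode(z / 0.18215).sample              # undo the encode-time scaling (assumed)
    images = images.clamp(-1, 1)                         # same clamp the old generate() applied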