v2: Use pre-trained SD-VAE, fix all bugs, pre-cache everything, massive speedup
Browse files- iris_model.py +72 -59
iris_model.py
CHANGED
|
@@ -989,71 +989,105 @@ class IRISGenerator(nn.Module):
|
|
| 989 |
# ============================================================================
|
| 990 |
|
| 991 |
class IRIS(nn.Module):
|
| 992 |
-
"""Complete IRIS system:
|
| 993 |
|
| 994 |
-
For training: use
|
| 995 |
-
For
|
|
|
|
| 996 |
"""
|
| 997 |
-
def __init__(self, config: IRISConfig):
|
| 998 |
super().__init__()
|
| 999 |
self.config = config
|
| 1000 |
-
self.vae = WaveletVAE(config)
|
| 1001 |
self.generator = IRISGenerator(config)
|
|
|
|
|
|
|
|
|
|
| 1002 |
|
| 1003 |
def encode(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 1004 |
-
"""Encode images
|
|
|
|
| 1005 |
return self.vae.encode(images)
|
| 1006 |
|
| 1007 |
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
| 1008 |
-
"""Decode latent
|
|
|
|
| 1009 |
return self.vae.decode(z)
|
| 1010 |
|
| 1011 |
-
|
|
|
|
| 1012 |
"""Rectified flow velocity target: v = noise - z_0."""
|
| 1013 |
return noise - z_0
|
| 1014 |
|
| 1015 |
-
|
|
|
|
| 1016 |
"""Rectified flow forward process: z_t = (1-t)*z_0 + t*noise."""
|
| 1017 |
t_expand = t[:, None, None, None]
|
| 1018 |
return (1 - t_expand) * z_0 + t_expand * noise
|
| 1019 |
|
| 1020 |
-
|
| 1021 |
-
|
| 1022 |
-
|
| 1023 |
-
"""
|
| 1024 |
u = torch.randn(batch_size, device=device)
|
| 1025 |
-
t = torch.sigmoid(u)
|
| 1026 |
-
# Clamp to avoid t=0 and t=1
|
| 1027 |
t = t.clamp(1e-5, 1 - 1e-5)
|
| 1028 |
return t
|
| 1029 |
|
| 1030 |
-
def
|
| 1031 |
self,
|
| 1032 |
-
|
| 1033 |
text_tokens: torch.Tensor,
|
| 1034 |
num_iterations: Optional[int] = None,
|
| 1035 |
) -> dict:
|
| 1036 |
-
"""
|
| 1037 |
|
| 1038 |
-
|
|
|
|
|
|
|
| 1039 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1040 |
B = images.shape[0]
|
| 1041 |
device = images.device
|
| 1042 |
|
| 1043 |
-
# Encode to latent
|
| 1044 |
z_0, mean, logvar = self.encode(images)
|
| 1045 |
-
|
| 1046 |
-
# Sample noise and timesteps
|
| 1047 |
noise = torch.randn_like(z_0)
|
| 1048 |
t = self.sample_timesteps(B, device)
|
| 1049 |
-
|
| 1050 |
-
# Create noisy latent
|
| 1051 |
z_t = self.add_noise(z_0, noise, t)
|
| 1052 |
|
| 1053 |
-
# Predict velocity
|
| 1054 |
-
# Randomly sample iteration count for training robustness
|
| 1055 |
if num_iterations is None:
|
| 1056 |
-
r_choices = [
|
| 1057 |
r = r_choices[torch.randint(0, len(r_choices), (1,)).item()]
|
| 1058 |
else:
|
| 1059 |
r = num_iterations
|
|
@@ -1061,15 +1095,9 @@ class IRIS(nn.Module):
|
|
| 1061 |
v_pred = self.generator(z_t, t, text_tokens, num_iterations=r)
|
| 1062 |
v_target = self.get_velocity_target(z_0, noise)
|
| 1063 |
|
| 1064 |
-
# SNR-weighted loss (from Rectified Flow paper)
|
| 1065 |
-
# w(t) = t / (1 - t) — emphasizes high-noise timesteps
|
| 1066 |
w = t / (1 - t + 1e-8)
|
| 1067 |
w = w[:, None, None, None]
|
| 1068 |
-
|
| 1069 |
-
# Velocity matching loss
|
| 1070 |
velocity_loss = (w * (v_pred - v_target).pow(2)).mean()
|
| 1071 |
-
|
| 1072 |
-
# VAE KL loss
|
| 1073 |
kl_loss = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()).mean()
|
| 1074 |
|
| 1075 |
return {
|
|
@@ -1080,7 +1108,7 @@ class IRIS(nn.Module):
|
|
| 1080 |
}
|
| 1081 |
|
| 1082 |
@torch.no_grad()
|
| 1083 |
-
def
|
| 1084 |
self,
|
| 1085 |
text_tokens: torch.Tensor,
|
| 1086 |
num_steps: int = 4,
|
|
@@ -1088,14 +1116,9 @@ class IRIS(nn.Module):
|
|
| 1088 |
cfg_scale: float = 4.0,
|
| 1089 |
seed: Optional[int] = None,
|
| 1090 |
) -> torch.Tensor:
|
| 1091 |
-
"""Generate
|
| 1092 |
-
|
| 1093 |
-
|
| 1094 |
-
text_tokens: [B, S, text_dim] CLIP text embeddings
|
| 1095 |
-
num_steps: Number of ODE solver steps (1-50)
|
| 1096 |
-
num_iterations: Core iterations per step (quality budget)
|
| 1097 |
-
cfg_scale: Classifier-free guidance scale
|
| 1098 |
-
seed: Random seed for reproducibility
|
| 1099 |
"""
|
| 1100 |
B, S, _ = text_tokens.shape
|
| 1101 |
device = text_tokens.device
|
|
@@ -1103,35 +1126,25 @@ class IRIS(nn.Module):
|
|
| 1103 |
if seed is not None:
|
| 1104 |
torch.manual_seed(seed)
|
| 1105 |
|
| 1106 |
-
# Start from pure noise
|
| 1107 |
z = torch.randn(B, self.config.latent_channels,
|
| 1108 |
self.config.latent_spatial, self.config.latent_spatial,
|
| 1109 |
device=device)
|
| 1110 |
|
| 1111 |
-
# Euler solver for rectified flow ODE: dz/dt = -v(z, t)
|
| 1112 |
-
# Integrate from t=1 (noise) to t=0 (data)
|
| 1113 |
dt = 1.0 / num_steps
|
| 1114 |
-
|
| 1115 |
for step in range(num_steps):
|
| 1116 |
t_val = 1.0 - step * dt
|
| 1117 |
t = torch.full((B,), t_val, device=device)
|
| 1118 |
|
| 1119 |
-
# Predict velocity
|
| 1120 |
v = self.generator(z, t, text_tokens, num_iterations=num_iterations)
|
| 1121 |
|
| 1122 |
-
# Classifier-free guidance (if cfg_scale > 1)
|
| 1123 |
if cfg_scale > 1.0:
|
| 1124 |
null_tokens = torch.zeros_like(text_tokens)
|
| 1125 |
v_uncond = self.generator(z, t, null_tokens, num_iterations=num_iterations)
|
| 1126 |
v = v_uncond + cfg_scale * (v - v_uncond)
|
| 1127 |
|
| 1128 |
-
# Euler step: z = z - dt * v
|
| 1129 |
z = z - dt * v
|
| 1130 |
|
| 1131 |
-
|
| 1132 |
-
images = self.decode(z)
|
| 1133 |
-
images = images.clamp(-1, 1)
|
| 1134 |
-
return images
|
| 1135 |
|
| 1136 |
|
| 1137 |
# ============================================================================
|
|
@@ -1162,9 +1175,9 @@ def estimate_memory_mb(model: nn.Module, dtype=torch.float16) -> float:
|
|
| 1162 |
|
| 1163 |
|
| 1164 |
def create_iris_small(latent_spatial: int = 32) -> IRIS:
|
| 1165 |
-
"""Create IRIS-Small
|
| 1166 |
config = IRISConfig(
|
| 1167 |
-
latent_channels=
|
| 1168 |
latent_spatial=latent_spatial,
|
| 1169 |
hidden_dim=512,
|
| 1170 |
num_heads=8,
|
|
@@ -1187,9 +1200,9 @@ def create_iris_small(latent_spatial: int = 32) -> IRIS:
|
|
| 1187 |
|
| 1188 |
|
| 1189 |
def create_iris_tiny(latent_spatial: int = 32) -> IRIS:
|
| 1190 |
-
"""Create IRIS-Tiny
|
| 1191 |
config = IRISConfig(
|
| 1192 |
-
latent_channels=
|
| 1193 |
latent_spatial=latent_spatial,
|
| 1194 |
hidden_dim=384,
|
| 1195 |
num_heads=6,
|
|
@@ -1212,9 +1225,9 @@ def create_iris_tiny(latent_spatial: int = 32) -> IRIS:
|
|
| 1212 |
|
| 1213 |
|
| 1214 |
def create_iris_base(latent_spatial: int = 32) -> IRIS:
|
| 1215 |
-
"""Create IRIS-Base
|
| 1216 |
config = IRISConfig(
|
| 1217 |
-
latent_channels=
|
| 1218 |
latent_spatial=latent_spatial,
|
| 1219 |
hidden_dim=768,
|
| 1220 |
num_heads=12,
|
|
|
|
| 989 |
# ============================================================================
|
| 990 |
|
| 991 |
class IRIS(nn.Module):
|
| 992 |
+
"""Complete IRIS system: Generator + optional built-in VAE.
|
| 993 |
|
| 994 |
+
For training with external VAE (recommended): use train_step_latent() with pre-encoded latents.
|
| 995 |
+
For training with built-in Wavelet VAE: use train_step() with raw images.
|
| 996 |
+
For inference: use generate_latent() to get latent, then decode externally.
|
| 997 |
"""
|
| 998 |
+
def __init__(self, config: IRISConfig, use_builtin_vae: bool = False):
    """Set up the generator and, optionally, the built-in Wavelet VAE.

    Args:
        config: Architecture/model configuration object.
        use_builtin_vae: When True, instantiate the built-in WaveletVAE;
            otherwise ``self.vae`` is None and an external (pre-trained)
            VAE must be used for encode/decode.
    """
    super().__init__()
    self.config = config
    self.generator = IRISGenerator(config)

    # The built-in Wavelet VAE is opt-in; a pre-trained external VAE is preferred.
    if use_builtin_vae:
        self.vae = WaveletVAE(config)
    else:
        self.vae = None
|
| 1005 |
|
| 1006 |
def encode(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode images via the built-in VAE (requires ``use_builtin_vae=True``).

    Args:
        images: Image batch accepted by the built-in WaveletVAE encoder.

    Returns:
        The result of ``self.vae.encode`` — per the annotation, a
        (z, mean, logvar) triple.

    Raises:
        RuntimeError: If the model was constructed without the built-in VAE.
    """
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would turn this guard into a confusing
    # AttributeError on `self.vae.encode` instead.
    if self.vae is None:
        raise RuntimeError("No built-in VAE. Use an external VAE to encode images.")
    return self.vae.encode(images)
|
| 1010 |
|
| 1011 |
def decode(self, z: torch.Tensor) -> torch.Tensor:
    """Decode a latent via the built-in VAE (requires ``use_builtin_vae=True``).

    Args:
        z: Latent tensor accepted by the built-in WaveletVAE decoder.

    Returns:
        The decoded output of ``self.vae.decode``.

    Raises:
        RuntimeError: If the model was constructed without the built-in VAE.
    """
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, so the guard must not rely on assertion semantics.
    if self.vae is None:
        raise RuntimeError("No built-in VAE. Use an external VAE to decode latents.")
    return self.vae.decode(z)
|
| 1015 |
|
| 1016 |
+
@staticmethod
|
| 1017 |
+
def get_velocity_target(z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
|
| 1018 |
"""Rectified flow velocity target: v = noise - z_0."""
|
| 1019 |
return noise - z_0
|
| 1020 |
|
| 1021 |
+
@staticmethod
|
| 1022 |
+
def add_noise(z_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
| 1023 |
"""Rectified flow forward process: z_t = (1-t)*z_0 + t*noise."""
|
| 1024 |
t_expand = t[:, None, None, None]
|
| 1025 |
return (1 - t_expand) * z_0 + t_expand * noise
|
| 1026 |
|
| 1027 |
+
@staticmethod
|
| 1028 |
+
def sample_timesteps(batch_size: int, device: torch.device) -> torch.Tensor:
|
| 1029 |
+
"""Sample timesteps from logit-normal distribution (from SD3/RF)."""
|
|
|
|
| 1030 |
u = torch.randn(batch_size, device=device)
|
| 1031 |
+
t = torch.sigmoid(u)
|
|
|
|
| 1032 |
t = t.clamp(1e-5, 1 - 1e-5)
|
| 1033 |
return t
|
| 1034 |
|
| 1035 |
+
def train_step_latent(
    self,
    z_0: torch.Tensor,
    text_tokens: torch.Tensor,
    num_iterations: Optional[int] = None,
) -> dict:
    """Training step on PRE-ENCODED latents (recommended path).

    Use this with an external pre-trained VAE:
        z_0 = external_vae.encode(images)  # done outside
        result = iris.train_step_latent(z_0, text_tokens)
    """
    batch = z_0.shape[0]
    device = z_0.device

    # Rectified-flow sample: interpolate the latent toward Gaussian noise at time t.
    noise = torch.randn_like(z_0)
    t = self.sample_timesteps(batch, device)
    z_t = self.add_noise(z_0, noise, t)

    # Core-iteration budget: fixed when given, otherwise drawn at random
    # so the generator trains robustly across iteration counts.
    if num_iterations is not None:
        iters = num_iterations
    else:
        choices = [3, 4, 5, 6]
        iters = choices[torch.randint(0, len(choices), (1,)).item()]

    v_pred = self.generator(z_t, t, text_tokens, num_iterations=iters)
    v_target = self.get_velocity_target(z_0, noise)

    # SNR-style weight t / (1 - t), broadcast over (C, H, W).
    weight = (t / (1 - t + 1e-8))[:, None, None, None]
    velocity_loss = (weight * (v_pred - v_target).pow(2)).mean()

    return {
        'loss': velocity_loss,
        'velocity_loss': velocity_loss.item(),
        'mean_t': t.mean().item(),
    }
|
| 1072 |
+
|
| 1073 |
+
def train_step(
|
| 1074 |
+
self,
|
| 1075 |
+
images: torch.Tensor,
|
| 1076 |
+
text_tokens: torch.Tensor,
|
| 1077 |
+
num_iterations: Optional[int] = None,
|
| 1078 |
+
) -> dict:
|
| 1079 |
+
"""Training step with built-in Wavelet VAE (legacy path)."""
|
| 1080 |
+
assert self.vae is not None, "No built-in VAE. Use train_step_latent() instead."
|
| 1081 |
B = images.shape[0]
|
| 1082 |
device = images.device
|
| 1083 |
|
|
|
|
| 1084 |
z_0, mean, logvar = self.encode(images)
|
|
|
|
|
|
|
| 1085 |
noise = torch.randn_like(z_0)
|
| 1086 |
t = self.sample_timesteps(B, device)
|
|
|
|
|
|
|
| 1087 |
z_t = self.add_noise(z_0, noise, t)
|
| 1088 |
|
|
|
|
|
|
|
| 1089 |
if num_iterations is None:
|
| 1090 |
+
r_choices = [3, 4, 5, 6]
|
| 1091 |
r = r_choices[torch.randint(0, len(r_choices), (1,)).item()]
|
| 1092 |
else:
|
| 1093 |
r = num_iterations
|
|
|
|
| 1095 |
v_pred = self.generator(z_t, t, text_tokens, num_iterations=r)
|
| 1096 |
v_target = self.get_velocity_target(z_0, noise)
|
| 1097 |
|
|
|
|
|
|
|
| 1098 |
w = t / (1 - t + 1e-8)
|
| 1099 |
w = w[:, None, None, None]
|
|
|
|
|
|
|
| 1100 |
velocity_loss = (w * (v_pred - v_target).pow(2)).mean()
|
|
|
|
|
|
|
| 1101 |
kl_loss = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()).mean()
|
| 1102 |
|
| 1103 |
return {
|
|
|
|
| 1108 |
}
|
| 1109 |
|
| 1110 |
@torch.no_grad()
|
| 1111 |
+
def generate_latent(
|
| 1112 |
self,
|
| 1113 |
text_tokens: torch.Tensor,
|
| 1114 |
num_steps: int = 4,
|
|
|
|
| 1116 |
cfg_scale: float = 4.0,
|
| 1117 |
seed: Optional[int] = None,
|
| 1118 |
) -> torch.Tensor:
|
| 1119 |
+
"""Generate latent (decode externally with your VAE).
|
| 1120 |
+
|
| 1121 |
+
Returns z_0 latent tensor, NOT decoded image.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1122 |
"""
|
| 1123 |
B, S, _ = text_tokens.shape
|
| 1124 |
device = text_tokens.device
|
|
|
|
| 1126 |
if seed is not None:
|
| 1127 |
torch.manual_seed(seed)
|
| 1128 |
|
|
|
|
| 1129 |
z = torch.randn(B, self.config.latent_channels,
|
| 1130 |
self.config.latent_spatial, self.config.latent_spatial,
|
| 1131 |
device=device)
|
| 1132 |
|
|
|
|
|
|
|
| 1133 |
dt = 1.0 / num_steps
|
|
|
|
| 1134 |
for step in range(num_steps):
|
| 1135 |
t_val = 1.0 - step * dt
|
| 1136 |
t = torch.full((B,), t_val, device=device)
|
| 1137 |
|
|
|
|
| 1138 |
v = self.generator(z, t, text_tokens, num_iterations=num_iterations)
|
| 1139 |
|
|
|
|
| 1140 |
if cfg_scale > 1.0:
|
| 1141 |
null_tokens = torch.zeros_like(text_tokens)
|
| 1142 |
v_uncond = self.generator(z, t, null_tokens, num_iterations=num_iterations)
|
| 1143 |
v = v_uncond + cfg_scale * (v - v_uncond)
|
| 1144 |
|
|
|
|
| 1145 |
z = z - dt * v
|
| 1146 |
|
| 1147 |
+
return z
|
|
|
|
|
|
|
|
|
|
| 1148 |
|
| 1149 |
|
| 1150 |
# ============================================================================
|
|
|
|
| 1175 |
|
| 1176 |
|
| 1177 |
def create_iris_small(latent_spatial: int = 32) -> IRIS:
|
| 1178 |
+
"""Create IRIS-Small for SD-VAE latent space (4ch, 8× downsample)."""
|
| 1179 |
config = IRISConfig(
|
| 1180 |
+
latent_channels=4,
|
| 1181 |
latent_spatial=latent_spatial,
|
| 1182 |
hidden_dim=512,
|
| 1183 |
num_heads=8,
|
|
|
|
| 1200 |
|
| 1201 |
|
| 1202 |
def create_iris_tiny(latent_spatial: int = 32) -> IRIS:
|
| 1203 |
+
"""Create IRIS-Tiny for SD-VAE latent space (4ch, 8× downsample)."""
|
| 1204 |
config = IRISConfig(
|
| 1205 |
+
latent_channels=4,
|
| 1206 |
latent_spatial=latent_spatial,
|
| 1207 |
hidden_dim=384,
|
| 1208 |
num_heads=6,
|
|
|
|
| 1225 |
|
| 1226 |
|
| 1227 |
def create_iris_base(latent_spatial: int = 32) -> IRIS:
|
| 1228 |
+
"""Create IRIS-Base for SD-VAE latent space (4ch, 8× downsample)."""
|
| 1229 |
config = IRISConfig(
|
| 1230 |
+
latent_channels=4,
|
| 1231 |
latent_spatial=latent_spatial,
|
| 1232 |
hidden_dim=768,
|
| 1233 |
num_heads=12,
|