asdf98 commited on
Commit
a1ff09a
·
verified ·
1 Parent(s): 4ad2cc3

Fix: streaming dataset (no full download), step-based training loop

Browse files
Files changed (1) hide show
  1. train.py +211 -104
train.py CHANGED
@@ -7,12 +7,15 @@ Flow Matching training objective (velocity prediction):
7
  - Loss: MSE(model(x_t, t), v)
8
 
9
  At inference: solve ODE from t=1 (noise) to t=0 (clean) using Euler steps.
 
 
 
10
  """
11
 
12
  import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
15
- from torch.utils.data import DataLoader, Dataset
16
  from torch.amp import autocast, GradScaler
17
  import math
18
  import os
@@ -37,6 +40,9 @@ class TrainConfig:
37
  dataset_config: str = ""
38
  image_column: str = "image"
39
  label_column: str = ""
 
 
 
40
 
41
  # VAE
42
  vae_id: str = "black-forest-labs/FLUX.1-schnell"
@@ -51,7 +57,7 @@ class TrainConfig:
51
  learning_rate: float = 1e-4
52
  weight_decay: float = 0.01
53
  max_grad_norm: float = 2.0
54
- num_epochs: int = 100
55
  warmup_steps: int = 1000
56
  ema_decay: float = 0.9999
57
  mixed_precision: bool = True
@@ -73,7 +79,7 @@ class TrainConfig:
73
 
74
  # System
75
  seed: int = 42
76
- num_workers: int = 2
77
  pin_memory: bool = True
78
  compile_model: bool = False
79
 
@@ -99,9 +105,114 @@ def get_model_config(size: str, num_classes: int = 0, class_drop_prob: float = 0
99
  return cfg
100
 
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  class EMAModel:
103
  """Exponential Moving Average of model parameters."""
104
-
105
  def __init__(self, model: nn.Module, decay: float = 0.9999):
106
  self.decay = decay
107
  self.shadow = {name: p.clone().detach() for name, p in model.named_parameters() if p.requires_grad}
@@ -132,41 +243,28 @@ class EMAModel:
132
 
133
 
134
  class FlowMatchingScheduler:
135
- """
136
- Flow Matching scheduler for training and sampling.
137
-
138
- Training: x_t = (1-t)*x_0 + t*ε, v_target = ε - x_0
139
- Sampling: Euler ODE from t=1 (noise) to t=0 (clean)
140
- """
141
 
142
- def __init__(self, min_t: float = 0.001, max_t: float = 0.999):
143
- self.min_t = min_t
144
- self.max_t = max_t
145
-
146
- def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
147
  return torch.rand(batch_size, device=device) * (self.max_t - self.min_t) + self.min_t
148
 
149
- def add_noise(self, x0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
150
- t_expand = t.view(-1, 1, 1, 1)
151
- return (1 - t_expand) * x0 + t_expand * noise
152
 
153
- def get_velocity_target(self, x0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
154
  return noise - x0
155
 
156
  @torch.no_grad()
157
- def sample(
158
- self, model: nn.Module, shape: tuple, device: torch.device,
159
- num_steps: int = 50, class_labels: Optional[torch.Tensor] = None,
160
- cfg_scale: float = 1.0, dtype: torch.dtype = torch.float32,
161
- ) -> torch.Tensor:
162
  model.eval()
163
  x = torch.randn(shape, device=device, dtype=dtype)
164
  dt = 1.0 / num_steps
165
- times = torch.linspace(1.0, dt, num_steps, device=device)
166
-
167
- for t_val in times:
168
  t = torch.full((shape[0],), t_val.item(), device=device, dtype=dtype)
169
-
170
  if cfg_scale > 1.0 and class_labels is not None:
171
  with torch.amp.autocast('cuda', enabled=(dtype != torch.float32)):
172
  v_cond = model(x, t, class_labels)
@@ -175,42 +273,39 @@ class FlowMatchingScheduler:
175
  else:
176
  with torch.amp.autocast('cuda', enabled=(dtype != torch.float32)):
177
  v = model(x, t, class_labels)
178
-
179
  x = x - dt * v
180
-
181
  return x
182
 
183
 
184
  def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
185
- """Cosine LR schedule with linear warmup."""
186
- def lr_lambda(current_step):
187
- if current_step < warmup_steps:
188
- return float(current_step) / float(max(1, warmup_steps))
189
- progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
190
  return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
191
  return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
192
 
193
 
194
  @torch.no_grad()
195
  def encode_images_with_vae(images, vae, scaling_factor, shift_factor):
196
- """Encode pixel images to VAE latents."""
197
  images = images * 2.0 - 1.0
198
  latents = vae.encode(images).latent_dist.sample()
199
- latents = (latents - shift_factor) * scaling_factor
200
- return latents
201
 
202
 
203
  @torch.no_grad()
204
  def decode_latents_with_vae(latents, vae, scaling_factor, shift_factor):
205
- """Decode VAE latents to pixel images."""
206
  latents = latents / scaling_factor + shift_factor
207
  images = vae.decode(latents).sample
208
- images = (images + 1.0) / 2.0
209
- return images.clamp(0, 1)
210
 
 
 
 
211
 
212
  def train(config: TrainConfig):
213
- """Main training loop."""
214
  from model import LiquidGen
215
 
216
  torch.manual_seed(config.seed)
@@ -224,7 +319,7 @@ def train(config: TrainConfig):
224
  with open(os.path.join(config.output_dir, "config.json"), "w") as f:
225
  json.dump(asdict(config), f, indent=2)
226
 
227
- # Load VAE
228
  print("Loading VAE...")
229
  from diffusers import AutoencoderKL
230
  vae_dtype = torch.float16 if config.vae_dtype == "float16" else torch.bfloat16
@@ -233,48 +328,39 @@ def train(config: TrainConfig):
233
  ).to(device).eval()
234
  for p in vae.parameters():
235
  p.requires_grad_(False)
 
236
 
237
  # Load Dataset
238
- print(f"Loading dataset: {config.dataset_name}")
239
- from datasets import load_dataset
240
- from torchvision import transforms
241
-
242
- ds_kwargs = {}
243
- if config.dataset_config:
244
- ds_kwargs["name"] = config.dataset_config
245
- dataset = load_dataset(config.dataset_name, split="train", **ds_kwargs)
246
-
247
- transform = transforms.Compose([
248
- transforms.Resize(config.image_size, interpolation=transforms.InterpolationMode.LANCZOS),
249
- transforms.CenterCrop(config.image_size),
250
- transforms.RandomHorizontalFlip(),
251
- transforms.ToTensor(),
252
- ])
253
-
254
- class ImageDataset(Dataset):
255
- def __init__(self, hf_dataset, transform, image_col, label_col=""):
256
- self.dataset = hf_dataset
257
- self.transform = transform
258
- self.image_col = image_col
259
- self.label_col = label_col
260
-
261
- def __len__(self):
262
- return len(self.dataset)
263
-
264
- def __getitem__(self, idx):
265
- item = self.dataset[idx]
266
- img = item[self.image_col]
267
- if img.mode != "RGB":
268
- img = img.convert("RGB")
269
- img = self.transform(img)
270
- label = item[self.label_col] if self.label_col and self.label_col in item else -1
271
- return img, label
272
-
273
- train_dataset = ImageDataset(dataset, transform, config.image_column, config.label_column)
274
- train_loader = DataLoader(
275
- train_dataset, batch_size=config.batch_size, shuffle=True,
276
- num_workers=config.num_workers, pin_memory=config.pin_memory, drop_last=True,
277
- )
278
 
279
  # Create Model
280
  model_kwargs = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
@@ -286,30 +372,36 @@ def train(config: TrainConfig):
286
 
287
  optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate,
288
  weight_decay=config.weight_decay, betas=(0.9, 0.999))
289
- total_steps = len(train_loader) * config.num_epochs // config.gradient_accumulation_steps
290
- scheduler = get_cosine_schedule_with_warmup(optimizer, config.warmup_steps, total_steps)
291
  ema = EMAModel(model, decay=config.ema_decay)
292
  scaler = GradScaler('cuda', enabled=config.mixed_precision)
293
  fm = FlowMatchingScheduler(min_t=config.min_timestep, max_t=config.max_timestep)
294
 
295
- print(f"\nTraining: {total_steps} steps, effective batch {config.batch_size * config.gradient_accumulation_steps}")
 
296
 
 
297
  global_step = 0
298
  loss_accum = 0.0
299
-
300
- for epoch in range(config.num_epochs):
301
- model.train()
302
- t_start = time.time()
303
-
304
- for batch_idx, (images, labels) in enumerate(train_loader):
 
 
 
305
  images = images.to(device)
306
  labels = labels.to(device) if config.num_classes > 0 else None
307
 
 
308
  with torch.no_grad():
309
  latents = encode_images_with_vae(
310
  images.to(vae_dtype), vae, config.vae_scaling_factor, config.vae_shift_factor
311
  ).float()
312
 
 
313
  t = fm.sample_timesteps(latents.shape[0], device)
314
  noise = torch.randn_like(latents)
315
  x_t = fm.add_noise(latents, noise, t)
@@ -321,8 +413,9 @@ def train(config: TrainConfig):
321
 
322
  scaler.scale(loss).backward()
323
  loss_accum += loss.item()
 
324
 
325
- if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
326
  scaler.unscale_(optimizer)
327
  grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
328
  scaler.step(optimizer)
@@ -332,17 +425,22 @@ def train(config: TrainConfig):
332
  ema.update(model)
333
  global_step += 1
334
 
 
335
  if global_step % config.log_every_n_steps == 0:
336
  avg_loss = loss_accum / config.log_every_n_steps
337
  lr = optimizer.param_groups[0]["lr"]
338
- print(f"step={global_step} | epoch={epoch} | loss={avg_loss:.4f} | "
339
- f"grad_norm={grad_norm:.2f} | lr={lr:.2e}")
 
 
 
340
  loss_accum = 0.0
341
 
342
  if math.isnan(avg_loss) or avg_loss > 100:
343
  print("⚠️ Training diverged!")
344
  return
345
 
 
346
  if global_step % config.sample_every_n_steps == 0:
347
  ema.apply(model)
348
  model.eval()
@@ -356,23 +454,32 @@ def train(config: TrainConfig):
356
  config.vae_scaling_factor, config.vae_shift_factor).float()
357
  from torchvision.utils import save_image
358
  save_image(sample_imgs, os.path.join(config.output_dir, "samples", f"step_{global_step:07d}.png"), nrow=2)
 
359
  ema.restore(model)
360
  model.train()
361
 
 
362
  if global_step % config.save_every_n_steps == 0:
 
363
  torch.save({
364
  "model": model.state_dict(), "ema": ema.state_dict(),
365
  "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(),
366
- "global_step": global_step, "epoch": epoch, "config": asdict(config),
367
- }, os.path.join(config.output_dir, "checkpoints", f"step_{global_step:07d}.pt"))
368
-
369
- print(f"Epoch {epoch} complete | time={time.time()-t_start:.0f}s")
370
-
371
- torch.save({"model": model.state_dict(), "ema": ema.state_dict(), "config": asdict(config),
372
- "global_step": global_step}, os.path.join(config.output_dir, "checkpoints", "final.pt"))
373
- print(f"Training complete! Final model saved.")
 
 
 
374
 
375
 
376
  if __name__ == "__main__":
377
- config = TrainConfig(model_size="small", image_size=256, batch_size=4, num_epochs=2)
 
 
 
378
  train(config)
 
7
  - Loss: MSE(model(x_t, t), v)
8
 
9
  At inference: solve ODE from t=1 (noise) to t=0 (clean) using Euler steps.
10
+
11
+ Dataset loading: Uses STREAMING mode by default — no full download needed!
12
+ For small datasets (<500MB), set use_streaming=False for faster epoch iteration.
13
  """
14
 
15
  import torch
16
  import torch.nn as nn
17
  import torch.nn.functional as F
18
+ from torch.utils.data import DataLoader, Dataset, IterableDataset
19
  from torch.amp import autocast, GradScaler
20
  import math
21
  import os
 
40
  dataset_config: str = ""
41
  image_column: str = "image"
42
  label_column: str = ""
43
+ use_streaming: bool = True # KEY: streaming mode, no full download
44
+ max_samples: int = 0 # 0 = use all (only for non-streaming)
45
+ streaming_buffer: int = 1000 # Shuffle buffer for streaming
46
 
47
  # VAE
48
  vae_id: str = "black-forest-labs/FLUX.1-schnell"
 
57
  learning_rate: float = 1e-4
58
  weight_decay: float = 0.01
59
  max_grad_norm: float = 2.0
60
+ max_steps: int = 50000 # Train by steps, not epochs (better for streaming)
61
  warmup_steps: int = 1000
62
  ema_decay: float = 0.9999
63
  mixed_precision: bool = True
 
79
 
80
  # System
81
  seed: int = 42
82
+ num_workers: int = 0 # 0 for streaming (required)
83
  pin_memory: bool = True
84
  compile_model: bool = False
85
 
 
105
  return cfg
106
 
107
 
108
+ # =============================================================================
109
+ # Dataset Loaders
110
+ # =============================================================================
111
+
112
+ class StreamingImageDataset(IterableDataset):
113
+ """
114
+ Streaming dataset — loads images on-the-fly from HuggingFace Hub.
115
+ NO full download needed. Starts training immediately.
116
+
117
+ Perfect for large datasets (WikiArt, LAION, etc.) on Colab free tier.
118
+ """
119
+ def __init__(self, dataset_name, image_column="image", label_column="",
120
+ image_size=256, split="train", dataset_config="",
121
+ buffer_size=1000, seed=42):
122
+ super().__init__()
123
+ self.dataset_name = dataset_name
124
+ self.image_column = image_column
125
+ self.label_column = label_column
126
+ self.split = split
127
+ self.dataset_config = dataset_config
128
+ self.buffer_size = buffer_size
129
+ self.seed = seed
130
+
131
+ from torchvision import transforms
132
+ self.transform = transforms.Compose([
133
+ transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),
134
+ transforms.CenterCrop(image_size),
135
+ transforms.RandomHorizontalFlip(),
136
+ transforms.ToTensor(),
137
+ ])
138
+
139
+ def _get_stream(self):
140
+ from datasets import load_dataset
141
+ kwargs = {}
142
+ if self.dataset_config:
143
+ kwargs["name"] = self.dataset_config
144
+ ds = load_dataset(self.dataset_name, split=self.split, streaming=True, **kwargs)
145
+ ds = ds.shuffle(seed=self.seed, buffer_size=self.buffer_size)
146
+ return iter(ds)
147
+
148
+ def __iter__(self):
149
+ stream = self._get_stream()
150
+ for item in stream:
151
+ try:
152
+ img = item[self.image_column]
153
+ if img.mode != "RGB":
154
+ img = img.convert("RGB")
155
+ img_tensor = self.transform(img)
156
+ label = -1
157
+ if self.label_column and self.label_column in item:
158
+ label = item[self.label_column]
159
+ yield img_tensor, label
160
+ except Exception:
161
+ continue
162
+
163
+
164
+ class MapImageDataset(Dataset):
165
+ """
166
+ Standard map-style dataset for small datasets that fit in memory.
167
+ Downloads once, then fast random access.
168
+
169
+ Good for: Pokemon (95MB), Flowers (330MB), few-shot-art (510MB)
170
+ """
171
+ def __init__(self, dataset_name, image_column="image", label_column="",
172
+ image_size=256, split="train", dataset_config="", max_samples=0):
173
+ super().__init__()
174
+ self.image_column = image_column
175
+ self.label_column = label_column
176
+
177
+ from datasets import load_dataset
178
+ from torchvision import transforms
179
+
180
+ kwargs = {}
181
+ if dataset_config:
182
+ kwargs["name"] = dataset_config
183
+
184
+ print(f"Downloading {dataset_name}...")
185
+ self.dataset = load_dataset(dataset_name, split=split, **kwargs)
186
+ if max_samples > 0:
187
+ self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))
188
+ print(f" {len(self.dataset)} images loaded")
189
+
190
+ self.transform = transforms.Compose([
191
+ transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),
192
+ transforms.CenterCrop(image_size),
193
+ transforms.RandomHorizontalFlip(),
194
+ transforms.ToTensor(),
195
+ ])
196
+
197
+ def __len__(self):
198
+ return len(self.dataset)
199
+
200
+ def __getitem__(self, idx):
201
+ item = self.dataset[idx]
202
+ img = item[self.image_column]
203
+ if img.mode != "RGB":
204
+ img = img.convert("RGB")
205
+ img = self.transform(img)
206
+ label = item[self.label_column] if self.label_column and self.label_column in item else -1
207
+ return img, label
208
+
209
+
210
+ # =============================================================================
211
+ # Training Utilities
212
+ # =============================================================================
213
+
214
  class EMAModel:
215
  """Exponential Moving Average of model parameters."""
 
216
  def __init__(self, model: nn.Module, decay: float = 0.9999):
217
  self.decay = decay
218
  self.shadow = {name: p.clone().detach() for name, p in model.named_parameters() if p.requires_grad}
 
243
 
244
 
245
  class FlowMatchingScheduler:
246
+ """Flow Matching: x_t = (1-t)*x_0 + t*ε, v_target = ε - x_0"""
247
+ def __init__(self, min_t=0.001, max_t=0.999):
248
+ self.min_t, self.max_t = min_t, max_t
 
 
 
249
 
250
+ def sample_timesteps(self, batch_size, device):
 
 
 
 
251
  return torch.rand(batch_size, device=device) * (self.max_t - self.min_t) + self.min_t
252
 
253
+ def add_noise(self, x0, noise, t):
254
+ t = t.view(-1, 1, 1, 1)
255
+ return (1 - t) * x0 + t * noise
256
 
257
+ def get_velocity_target(self, x0, noise):
258
  return noise - x0
259
 
260
  @torch.no_grad()
261
+ def sample(self, model, shape, device, num_steps=50, class_labels=None,
262
+ cfg_scale=1.0, dtype=torch.float32):
 
 
 
263
  model.eval()
264
  x = torch.randn(shape, device=device, dtype=dtype)
265
  dt = 1.0 / num_steps
266
+ for t_val in torch.linspace(1.0, dt, num_steps, device=device):
 
 
267
  t = torch.full((shape[0],), t_val.item(), device=device, dtype=dtype)
 
268
  if cfg_scale > 1.0 and class_labels is not None:
269
  with torch.amp.autocast('cuda', enabled=(dtype != torch.float32)):
270
  v_cond = model(x, t, class_labels)
 
273
  else:
274
  with torch.amp.autocast('cuda', enabled=(dtype != torch.float32)):
275
  v = model(x, t, class_labels)
 
276
  x = x - dt * v
 
277
  return x
278
 
279
 
280
  def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
281
+ def lr_lambda(step):
282
+ if step < warmup_steps:
283
+ return step / max(1, warmup_steps)
284
+ progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
 
285
  return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
286
  return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
287
 
288
 
289
  @torch.no_grad()
290
  def encode_images_with_vae(images, vae, scaling_factor, shift_factor):
 
291
  images = images * 2.0 - 1.0
292
  latents = vae.encode(images).latent_dist.sample()
293
+ return (latents - shift_factor) * scaling_factor
 
294
 
295
 
296
  @torch.no_grad()
297
  def decode_latents_with_vae(latents, vae, scaling_factor, shift_factor):
 
298
  latents = latents / scaling_factor + shift_factor
299
  images = vae.decode(latents).sample
300
+ return ((images + 1.0) / 2.0).clamp(0, 1)
301
+
302
 
303
+ # =============================================================================
304
+ # Main Training Loop
305
+ # =============================================================================
306
 
307
  def train(config: TrainConfig):
308
+ """Main training loop with streaming dataset support."""
309
  from model import LiquidGen
310
 
311
  torch.manual_seed(config.seed)
 
319
  with open(os.path.join(config.output_dir, "config.json"), "w") as f:
320
  json.dump(asdict(config), f, indent=2)
321
 
322
+ # Load VAE (frozen)
323
  print("Loading VAE...")
324
  from diffusers import AutoencoderKL
325
  vae_dtype = torch.float16 if config.vae_dtype == "float16" else torch.bfloat16
 
328
  ).to(device).eval()
329
  for p in vae.parameters():
330
  p.requires_grad_(False)
331
+ print(f"VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params (frozen)")
332
 
333
  # Load Dataset
334
+ print(f"Loading dataset: {config.dataset_name} (streaming={config.use_streaming})")
335
+ if config.use_streaming:
336
+ train_dataset = StreamingImageDataset(
337
+ dataset_name=config.dataset_name,
338
+ image_column=config.image_column,
339
+ label_column=config.label_column,
340
+ image_size=config.image_size,
341
+ dataset_config=config.dataset_config,
342
+ buffer_size=config.streaming_buffer,
343
+ seed=config.seed,
344
+ )
345
+ train_loader = DataLoader(
346
+ train_dataset, batch_size=config.batch_size,
347
+ num_workers=0, # Required for streaming
348
+ pin_memory=config.pin_memory,
349
+ )
350
+ print(" Streaming mode — no full download, starts immediately!")
351
+ else:
352
+ train_dataset = MapImageDataset(
353
+ dataset_name=config.dataset_name,
354
+ image_column=config.image_column,
355
+ label_column=config.label_column,
356
+ image_size=config.image_size,
357
+ dataset_config=config.dataset_config,
358
+ max_samples=config.max_samples,
359
+ )
360
+ train_loader = DataLoader(
361
+ train_dataset, batch_size=config.batch_size, shuffle=True,
362
+ num_workers=2, pin_memory=config.pin_memory, drop_last=True,
363
+ )
 
 
 
 
 
 
 
 
 
 
364
 
365
  # Create Model
366
  model_kwargs = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
 
372
 
373
  optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate,
374
  weight_decay=config.weight_decay, betas=(0.9, 0.999))
375
+ scheduler = get_cosine_schedule_with_warmup(optimizer, config.warmup_steps, config.max_steps)
 
376
  ema = EMAModel(model, decay=config.ema_decay)
377
  scaler = GradScaler('cuda', enabled=config.mixed_precision)
378
  fm = FlowMatchingScheduler(min_t=config.min_timestep, max_t=config.max_timestep)
379
 
380
+ print(f"\nTraining for {config.max_steps} steps")
381
+ print(f"Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")
382
 
383
+ # Step-based training loop (works for both streaming and map datasets)
384
  global_step = 0
385
  loss_accum = 0.0
386
+ accum_count = 0
387
+ model.train()
388
+ t_start = time.time()
389
+
390
+ while global_step < config.max_steps:
391
+ for images, labels in train_loader:
392
+ if global_step >= config.max_steps:
393
+ break
394
+
395
  images = images.to(device)
396
  labels = labels.to(device) if config.num_classes > 0 else None
397
 
398
+ # Encode to latents
399
  with torch.no_grad():
400
  latents = encode_images_with_vae(
401
  images.to(vae_dtype), vae, config.vae_scaling_factor, config.vae_shift_factor
402
  ).float()
403
 
404
+ # Flow matching
405
  t = fm.sample_timesteps(latents.shape[0], device)
406
  noise = torch.randn_like(latents)
407
  x_t = fm.add_noise(latents, noise, t)
 
413
 
414
  scaler.scale(loss).backward()
415
  loss_accum += loss.item()
416
+ accum_count += 1
417
 
418
+ if accum_count % config.gradient_accumulation_steps == 0:
419
  scaler.unscale_(optimizer)
420
  grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
421
  scaler.step(optimizer)
 
425
  ema.update(model)
426
  global_step += 1
427
 
428
+ # Logging
429
  if global_step % config.log_every_n_steps == 0:
430
  avg_loss = loss_accum / config.log_every_n_steps
431
  lr = optimizer.param_groups[0]["lr"]
432
+ elapsed = time.time() - t_start
433
+ steps_per_sec = global_step / max(elapsed, 1)
434
+ print(f"step={global_step} | loss={avg_loss:.4f} | "
435
+ f"grad_norm={grad_norm:.2f} | lr={lr:.2e} | "
436
+ f"steps/s={steps_per_sec:.2f} | elapsed={elapsed:.0f}s")
437
  loss_accum = 0.0
438
 
439
  if math.isnan(avg_loss) or avg_loss > 100:
440
  print("⚠️ Training diverged!")
441
  return
442
 
443
+ # Sample
444
  if global_step % config.sample_every_n_steps == 0:
445
  ema.apply(model)
446
  model.eval()
 
454
  config.vae_scaling_factor, config.vae_shift_factor).float()
455
  from torchvision.utils import save_image
456
  save_image(sample_imgs, os.path.join(config.output_dir, "samples", f"step_{global_step:07d}.png"), nrow=2)
457
+ print(f" 📸 Saved samples: step_{global_step:07d}.png")
458
  ema.restore(model)
459
  model.train()
460
 
461
+ # Checkpoint
462
  if global_step % config.save_every_n_steps == 0:
463
+ ckpt_path = os.path.join(config.output_dir, "checkpoints", f"step_{global_step:07d}.pt")
464
  torch.save({
465
  "model": model.state_dict(), "ema": ema.state_dict(),
466
  "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(),
467
+ "global_step": global_step, "config": asdict(config),
468
+ }, ckpt_path)
469
+ print(f" 💾 Checkpoint: {ckpt_path}")
470
+
471
+ # Final save
472
+ final_path = os.path.join(config.output_dir, "checkpoints", "final.pt")
473
+ torch.save({"model": model.state_dict(), "ema": ema.state_dict(),
474
+ "config": asdict(config), "global_step": global_step}, final_path)
475
+ elapsed = time.time() - t_start
476
+ print(f"\n🎉 Training complete! {global_step} steps in {elapsed/60:.1f} min")
477
+ print(f" Final model: {final_path}")
478
 
479
 
480
  if __name__ == "__main__":
481
+ config = TrainConfig(
482
+ model_size="small", image_size=256, batch_size=4,
483
+ max_steps=100, use_streaming=True,
484
+ )
485
  train(config)