asdf98 committed
Commit 0b46772 · verified · 1 Parent(s): 2403335

Fix dataset: drop broken keremberke, use pure-parquet datasets only (cartoon default, Artificio/WikiArt for styles)

Files changed (1)
  1. train.py  +58 -65
train.py CHANGED
@@ -3,14 +3,13 @@ LiquidGen Training Pipeline v2
 
 Optimized for Colab free tier:
 - Latent pre-caching: encode images with VAE once, save to disk, train on pure tensors
-- No VAE needed during training loop saves ~1GB VRAM + faster iterations
-- Streaming support for large datasets
-- Multiple small dataset presets
+- No VAE needed during training loop -> saves ~1GB VRAM + faster iterations
+- All datasets are pure parquet — no legacy loading scripts
 - Uses madebyollin/sdxl-vae-fp16-fix (fully open, no login, fp16 stable)
 
 Flow Matching training objective (velocity prediction):
-- Forward: x_t = (1 - t) * x_0 + t * ε
-- Target: v = ε - x_0
+- Forward: x_t = (1 - t) * x_0 + t * eps
+- Target: v = eps - x_0
 - Loss: MSE(model(x_t, t), v)
 """
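The objective listed in the docstring is only a few lines of PyTorch in practice. A minimal sketch of the velocity-prediction loss, assuming a model(x_t, t) call that returns a tensor shaped like x_t (the names below are illustrative, not this repo's exact API):

    import torch
    import torch.nn.functional as F

    def flow_matching_loss(model, x_0):
        # x_0: clean VAE latents [B, C, H, W]
        b = x_0.shape[0]
        t = torch.rand(b, device=x_0.device).view(b, 1, 1, 1)  # t ~ U(0, 1)
        eps = torch.randn_like(x_0)
        x_t = (1 - t) * x_0 + t * eps     # straight path between data and noise
        v_target = eps - x_0              # velocity the model should predict
        v_pred = model(x_t, t.flatten())  # assumed signature: model(latents, timesteps)
        return F.mse_loss(v_pred, v_target)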
 
@@ -28,35 +27,17 @@ from dataclasses import dataclass, asdict
 
 
 # =============================================================================
-# Dataset Presets (all verified, fast to download, no auth needed)
+# Dataset Presets - ALL pure parquet, no loading scripts, no auth
 # =============================================================================
 
 DATASET_PRESETS = {
-    "paintings_mini": {
-        "name": "keremberke/painting-style-classification",
-        "config": "mini",
-        "image_column": "image",
-        "label_column": "labels",
-        "num_classes": 27,
-        "trust_remote_code": True,
-        "description": "~200 painting samples, 27 styles, 1.7MB — instant smoke test",
-    },
-    "paintings": {
-        "name": "keremberke/painting-style-classification",
-        "config": "full",
-        "image_column": "image",
-        "label_column": "labels",
-        "num_classes": 27,
-        "trust_remote_code": True,
-        "description": "~8K paintings, 27 styles, 204MB — best for style-conditional training",
-    },
     "cartoon": {
         "name": "Norod78/cartoon-blip-captions",
         "config": "",
         "image_column": "image",
         "label_column": "",
         "num_classes": 0,
-        "description": "~2.5K cartoon/anime, unconditional, 181MB",
+        "description": "~2.5K cartoon/anime images, unconditional, 181MB",
     },
     "flowers": {
         "name": "huggan/flowers-102-categories",
@@ -66,14 +47,21 @@ DATASET_PRESETS = {
         "num_classes": 0,
         "description": "~8K flower photos, unconditional, 331MB",
     },
-    "wikiart_stream": {
-        "name": "huggan/wikiart",
+    "wikiart": {
+        "name": "Artificio/WikiArt",
         "config": "",
         "image_column": "image",
         "label_column": "style",
-        "num_classes": 27,
-        "streaming": True,
-        "description": "~80K paintings, 27 styles, STREAMING (0 disk) — use max_images to limit",
+        "num_classes": 0,  # string labels, mapped to ints automatically
+        "description": "~105K paintings with style labels, 1.6GB (use max_images to limit)",
+    },
+    "art_painting": {
+        "name": "huggan/few-shot-art-painting",
+        "config": "",
+        "image_column": "image",
+        "label_column": "",
+        "num_classes": 0,
+        "description": "~6K art paintings, unconditional, 511MB",
     },
 }
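Because every preset now points at a plain parquet repo, each entry can be smoke-tested with a tiny split slice before committing to a full download. A quick check one might run in a Colab cell, assuming DATASET_PRESETS from this file is in scope:

    from datasets import load_dataset

    for key, preset in DATASET_PRESETS.items():
        kwargs = {"name": preset["config"]} if preset["config"] else {}
        # a 4-row slice is enough to confirm the repo resolves without scripts or auth
        ds = load_dataset(preset["name"], split="train[:4]", **kwargs)
        print(key, ds.features)

Split slicing (train[:4]) and the name= config kwarg are standard datasets behaviour, so no trust_remote_code flag is needed anywhere.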
 
@@ -83,13 +71,13 @@ class TrainConfig:
     """Training configuration optimized for Colab free tier (T4 16GB)."""
     # Model
     model_size: str = "small"  # small (~55M), base (~140M), large (~280M)
-    num_classes: int = 27
+    num_classes: int = 0  # 0 = unconditional
     class_drop_prob: float = 0.1
 
     # Data
-    dataset_preset: str = "paintings"  # key from DATASET_PRESETS
+    dataset_preset: str = "cartoon"  # key from DATASET_PRESETS
     image_size: int = 256  # 256 or 512
-    max_images: int = 0  # 0 = use all, >0 = limit (for streaming/testing)
+    max_images: int = 0  # 0 = use all, >0 = limit
 
     # VAE — fully open, no login needed
     vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
@@ -161,8 +149,8 @@ class CachedLatentDataset(Dataset):
         self.labels = data.get("labels", None)
         print(f"Loaded {len(self.latents)} cached latents from {cache_path}")
         print(f"  Shape: {self.latents.shape}, dtype: {self.latents.dtype}")
-        if self.labels is not None:
-            print(f"  Labels: unique={self.labels.unique().shape[0]}")
+        if self.labels is not None and (self.labels >= 0).any():
+            print(f"  Labels: unique={self.labels[self.labels >= 0].unique().shape[0]}")
 
     def __len__(self):
         return len(self.latents)
@@ -176,46 +164,39 @@
 def precache_latents(config, cache_path=None):
     """
     Encode all images to VAE latents once, save to disk.
-    Uses madebyollin/sdxl-vae-fp16-fix (no auth needed).
     """
     if cache_path is None:
         cache_path = os.path.join(config.output_dir, "cached_latents.pt")
 
     if os.path.exists(cache_path):
-        print(f"Cache exists: {cache_path}")
+        print(f"Cache exists: {cache_path}")
         data = torch.load(cache_path, map_location="cpu", weights_only=True)
-        print(f"  {data['latents'].shape[0]} latents, shape {data['latents'].shape[1:]}")
+        print(f"  {data['latents'].shape[0]} latents, shape {data['latents'].shape[1:]}")
         return cache_path
 
     os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     # Load VAE
-    print(f"Loading VAE: {config.vae_id} (open, no login needed)...")
+    print(f"Loading VAE: {config.vae_id}...")
     from diffusers import AutoencoderKL
     vae = AutoencoderKL.from_pretrained(
         config.vae_id, torch_dtype=torch.float16
     ).to(device).eval()
     for p in vae.parameters():
         p.requires_grad_(False)
-    print(f"  VAE loaded: {sum(p.numel() for p in vae.parameters())/1e6:.0f}M params")
+    print(f"  VAE: {sum(p.numel() for p in vae.parameters())/1e6:.0f}M params")
 
     # Load dataset
     preset = DATASET_PRESETS[config.dataset_preset]
-    print(f"Loading dataset: {preset['name']} ({preset['description']})")
+    print(f"Loading: {preset['name']} ({preset['description']})")
 
     from datasets import load_dataset
     from torchvision import transforms
 
-    is_streaming = preset.get("streaming", False)
     ds_kwargs = {"split": "train"}
     if preset["config"]:
         ds_kwargs["name"] = preset["config"]
-    if is_streaming:
-        ds_kwargs["streaming"] = True
-    # Some datasets have legacy loading scripts that need this flag
-    if preset.get("trust_remote_code", False):
-        ds_kwargs["trust_remote_code"] = True
 
     dataset = load_dataset(preset["name"], **ds_kwargs)
 
@@ -225,6 +206,11 @@ def precache_latents(config, cache_path=None):
         transforms.ToTensor(),
     ])
 
+    # For Artificio/WikiArt: style is a string, map to int
+    img_col = preset["image_column"]
+    lbl_col = preset["label_column"]
+    style_to_id = {}
+
     all_latents = []
     all_labels = []
     batch_pixels = []
@@ -232,10 +218,8 @@ def precache_latents(config, cache_path=None):
     encode_bs = 16
     count = 0
     max_imgs = config.max_images if config.max_images > 0 else float("inf")
-    img_col = preset["image_column"]
-    lbl_col = preset["label_column"]
 
-    print(f"Encoding images to VAE latents...")
+    print(f"Encoding to VAE latents...")
     t0 = time.time()
 
     for item in dataset:
@@ -245,8 +229,18 @@ def precache_latents(config, cache_path=None):
         if img.mode != "RGB":
             img = img.convert("RGB")
         batch_pixels.append(transform(img))
+
+        # Handle labels: int or string
         if lbl_col and lbl_col in item:
-            batch_labels.append(item[lbl_col])
+            raw_label = item[lbl_col]
+            if isinstance(raw_label, str):
+                if raw_label not in style_to_id:
+                    style_to_id[raw_label] = len(style_to_id)
+                batch_labels.append(style_to_id[raw_label])
+            elif isinstance(raw_label, int):
+                batch_labels.append(raw_label)
+            else:
+                batch_labels.append(-1)
         else:
             batch_labels.append(-1)
         count += 1
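The string branch above assigns class IDs in encounter order, so the resulting mapping follows dataset order; for illustration, with made-up style strings:

    style_to_id = {}
    for raw_label in ["Baroque", "Cubism", "Baroque", "Realism"]:
        if raw_label not in style_to_id:
            style_to_id[raw_label] = len(style_to_id)
    # style_to_id == {"Baroque": 0, "Cubism": 1, "Realism": 2}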
@@ -260,7 +254,7 @@ def precache_latents(config, cache_path=None):
             all_labels.extend(batch_labels)
             batch_pixels, batch_labels = [], []
             if count % 500 == 0:
-                print(f"  {count} images encoded ({time.time()-t0:.0f}s)")
+                print(f"  {count} images ({time.time()-t0:.0f}s)")
 
     if batch_pixels:
         with torch.no_grad():
@@ -272,17 +266,22 @@ def precache_latents(config, cache_path=None):
 
     all_latents = torch.cat(all_latents, dim=0)
     all_labels = torch.tensor(all_labels, dtype=torch.long)
-    torch.save({"latents": all_latents, "labels": all_labels}, cache_path)
+
+    save_data = {"latents": all_latents, "labels": all_labels}
+    if style_to_id:
+        save_data["style_to_id"] = style_to_id
+        print(f"  Mapped {len(style_to_id)} style labels to class IDs")
+    torch.save(save_data, cache_path)
 
     elapsed = time.time() - t0
     mb = os.path.getsize(cache_path) / 1024**2
-    print(f"\n✅ Cached {count} latents -> {cache_path}")
-    print(f"  Shape: {all_latents.shape}, Size: {mb:.1f}MB, Time: {elapsed:.0f}s")
+    print(f"\nCached {count} latents -> {cache_path}")
+    print(f"  Shape: {all_latents.shape}, {mb:.1f}MB, {elapsed:.0f}s")
 
     del vae
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-    print("  VAE unloaded, VRAM freed\n")
+    print("  VAE unloaded\n")
     return cache_path
 
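Since style_to_id is stored next to the tensors, the cache stays self-describing. A sketch of inspecting it afterwards (the file is written to config.output_dir as cached_latents.pt; weights_only=True is fine because the payload is only tensors, ints and strings):

    import torch

    # adjust the path to your run's output_dir
    data = torch.load("cached_latents.pt", map_location="cpu", weights_only=True)
    print(data["latents"].shape, data["labels"].shape)
    if "style_to_id" in data:  # only present for string-labelled presets such as wikiart
        id_to_style = {v: k for k, v in data["style_to_id"].items()}
        print(len(id_to_style), "styles; class 0 ->", id_to_style[0])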
 
@@ -371,15 +370,12 @@ def train(config):
     with open(f"{config.output_dir}/config.json", "w") as f:
         json.dump(asdict(config), f, indent=2)
 
-    # Step 1: Pre-cache latents
     cache_path = precache_latents(config)
 
-    # Step 2: Dataset from cache
     train_ds = CachedLatentDataset(cache_path)
     train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
                           num_workers=config.num_workers, pin_memory=True, drop_last=True)
 
-    # Step 3: Model
     mcfg = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
     mcfg["in_channels"] = config.latent_channels
     model = LiquidGen(**mcfg).to(device)
@@ -388,7 +384,6 @@ def train(config):
     if config.compile_model and hasattr(torch, "compile"):
         model = torch.compile(model)
 
-    # Step 4: Training setup
     opt = torch.optim.AdamW(model.parameters(), lr=config.learning_rate,
                             weight_decay=config.weight_decay, betas=(0.9, 0.999))
     total_steps = len(train_dl) * config.num_epochs // config.gradient_accumulation_steps
@@ -398,16 +393,14 @@ def train(config):
     fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep)
     lat_size = config.image_size // 8
 
-    print(f"\nTotal steps: {total_steps}, Batch: {config.batch_size}x{config.gradient_accumulation_steps}")
+    print(f"\nSteps: {total_steps}, Batch: {config.batch_size}x{config.gradient_accumulation_steps}")
     print(f"Latent: [{config.batch_size}, {config.latent_channels}, {lat_size}, {lat_size}]")
-    print(f"No VAE during training -> max VRAM for model")
     if torch.cuda.is_available():
         print(f"VRAM: {torch.cuda.memory_allocated()/1024**3:.1f} / "
               f"{torch.cuda.get_device_properties(0).total_mem/1024**3:.1f} GB")
 
-    # Step 5: Train!
     gs = 0; la = 0.0; vae = None; vae_loaded = False
-    print(f"\n{'='*60}\n🚀 Training!\n{'='*60}\n")
+    print(f"\n{'='*60}\nTraining!\n{'='*60}\n")
     t_start = time.time()
 
     for epoch in range(config.num_epochs):
@@ -483,7 +476,7 @@
 
 if __name__ == "__main__":
     config = TrainConfig(
-        model_size="small", dataset_preset="paintings_mini",
+        model_size="small", dataset_preset="cartoon",
        image_size=256, batch_size=8, num_epochs=5,
        log_every_n_steps=5, sample_every_n_steps=99999,
    )