Add ETA to every training log line + epoch summary

train.py CHANGED

@@ -6,6 +6,7 @@ Optimized for Colab free tier:
 - Auto-limits large datasets (WikiArt capped at 10K by default)
 - Latent pre-caching: train on pure tensors, no VAE during training
 - Gradient checkpointing + auto batch size = no OOM
+- ETA shown on every log line
 - All datasets pure parquet, open SDXL VAE (no login)
 """
 
@@ -28,7 +29,7 @@ DATASET_PRESETS = {
         "image_column": "image",
         "label_column": "",
         "num_classes": 0,
-        "max_default": 0,
+        "max_default": 0,
         "description": "~2.5K cartoon/anime, unconditional, 181MB — fast",
     },
     "flowers": {
@@ -46,7 +47,7 @@ DATASET_PRESETS = {
         "image_column": "image",
         "label_column": "style",
         "num_classes": 0,
-        "max_default": 10000,
+        "max_default": 10000,
         "description": "~105K paintings with styles (auto-capped to 10K for speed)",
     },
     "art_painting": {
@@ -62,7 +63,6 @@ DATASET_PRESETS = {
 
 
 def auto_batch_size(model_size, image_size, gpu_mem_gb):
-    """Safe batch size for model + resolution + GPU."""
     param_mem = {"small": 0.66, "base": 1.68, "large": 3.35}
     base = param_mem.get(model_size, 1.0)
     act_per_sample = {"small": {256: 0.02, 512: 0.07},
@@ -78,6 +78,13 @@ def auto_batch_size(model_size, image_size, gpu_mem_gb):
     return max(1, bs)
 
 
+def _fmt_time(seconds):
+    """Format seconds into human readable string."""
+    if seconds < 60: return f"{seconds:.0f}s"
+    if seconds < 3600: return f"{seconds/60:.1f}m"
+    return f"{seconds/3600:.1f}h"
+
+
 @dataclass
 class TrainConfig:
     model_size: str = "small"
@@ -85,11 +92,11 @@ class TrainConfig:
     class_drop_prob: float = 0.1
     dataset_preset: str = "cartoon"
     image_size: int = 256
-    max_images: int = 0
+    max_images: int = 0
     vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
     vae_scaling_factor: float = 0.13025
     latent_channels: int = 4
-    batch_size: int = 0
+    batch_size: int = 0
     gradient_accumulation_steps: int = 1
     learning_rate: float = 1e-4
     weight_decay: float = 0.01
@@ -175,7 +182,6 @@ def precache_latents(config, cache_path=None):
         transforms.CenterCrop(config.image_size), transforms.ToTensor(),
     ])
 
-    # Determine max images: user override > dataset default > all
    if config.max_images > 0:
         max_imgs = config.max_images
     elif preset.get("max_default", 0) > 0:
@@ -184,10 +190,9 @@ def precache_latents(config, cache_path=None):
     else:
         max_imgs = len(dataset)
     max_imgs = min(max_imgs, len(dataset))
-    print(f"   Encoding {max_imgs} of {len(dataset)} images")
 
-    # VAE encode batch size: bigger = faster. 64 for 256px, 32 for 512px
     encode_bs = 64 if config.image_size <= 256 else 32
+    print(f"   Encoding {max_imgs} images (batch={encode_bs})...")
 
     img_col, lbl_col = preset["image_column"], preset["label_column"]
     style_to_id = {}
@@ -220,7 +225,7 @@ def precache_latents(config, cache_path=None):
             speed = count / elapsed
             eta = (max_imgs - count) / speed if speed > 0 else 0
             if count % (encode_bs * 4) == 0:
-                print(f"   {count}/{max_imgs}
+                print(f"   {count}/{max_imgs} | {speed:.0f} img/s | ETA {_fmt_time(eta)}")
 
     if batch_px:
         with torch.no_grad():
@@ -237,8 +242,7 @@ def precache_latents(config, cache_path=None):
         print(f"   {len(style_to_id)} style classes")
     torch.save(save_data, cache_path)
     mb = os.path.getsize(cache_path) / 1024**2
-
-    print(f"Cached {count} latents -> {cache_path} ({mb:.0f}MB, {elapsed:.0f}s)")
+    print(f"Cached {count} latents -> {cache_path} ({mb:.0f}MB, {_fmt_time(time.time()-t0)})")
     del vae
     if torch.cuda.is_available(): torch.cuda.empty_cache()
     return cache_path
@@ -330,7 +334,7 @@ def train(config):
     scaler = GradScaler("cuda", enabled=config.mixed_precision and torch.cuda.is_available())
     fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep)
     lat_size = config.image_size // 8
-    print(f"Steps: {total_steps}
+    print(f"Steps: {total_steps} | Batch: {config.batch_size} | Epochs: {config.num_epochs}")
 
     gs = 0; la = 0.0; vae = None; vae_loaded = False
     print(f"\nTraining!\n")
@@ -355,10 +359,14 @@ def train(config):
                 ema.update(model); gs += 1
                 if gs % config.log_every_n_steps == 0:
                     al = la / config.log_every_n_steps
+                    elapsed = time.time() - t_start
+                    sps = gs / max(elapsed, 1)
+                    remaining = (total_steps - gs) / sps if sps > 0 else 0
                     vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0
-
-
-                          f"{
+                    pct = gs / total_steps * 100
+                    print(f"step={gs:>6d}/{total_steps} ({pct:.0f}%) | ep={epoch} | "
+                          f"loss={al:.4f} | gn={gn:.2f} | lr={opt.param_groups[0]['lr']:.2e} | "
+                          f"vram={vram:.1f}G | {sps:.1f} st/s | ETA {_fmt_time(remaining)}")
                     la = 0.0
                     if math.isnan(al) or al > 50: print("Diverged!"); return
                     if gs % config.sample_every_n_steps == 0:
@@ -380,8 +388,11 @@ def train(config):
             torch.save({"model": model.state_dict(), "ema": ema.shadow,
                         "optimizer": opt.state_dict(), "step": gs, "model_config": mcfg},
                        f"{config.output_dir}/checkpoints/step_{gs:07d}.pt")
-
+        ep_time = time.time() - et
+        ep_eta = ep_time * (config.num_epochs - epoch - 1)
+        print(f"Epoch {epoch}/{config.num_epochs} done | {_fmt_time(ep_time)} | ETA {_fmt_time(ep_eta)}\n")
 
     final = f"{config.output_dir}/checkpoints/final.pt"
     torch.save({"model": model.state_dict(), "ema": ema.shadow, "model_config": mcfg, "step": gs}, final)
-
+    total_time = time.time() - t_start
+    print(f"\nDone! {gs} steps in {_fmt_time(total_time)} -> {final}")