AbstractPhil committed on
Commit
61d1a19
·
verified ·
1 Parent(s): 29692b5

Update trainer_colab.py

Browse files
Files changed (1) hide show
  1. trainer_colab.py +334 -454
trainer_colab.py CHANGED
@@ -1,9 +1,15 @@
1
  # ============================================================================
2
- # TinyFlux Training Cell - Full Featured
3
  # ============================================================================
4
- # Run the model cell before this one (defines TinyFlux, TinyFluxConfig)
5
- # Dataset: AbstractPhil/flux-schnell-teacher-latents
6
- # Uploads checkpoints to: AbstractPhil/tiny-flux
 
 
 
 
 
 
7
  # ============================================================================
8
 
9
  import torch
@@ -21,11 +27,24 @@ import os
21
  import json
22
  from datetime import datetime
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # ============================================================================
25
  # CONFIG
26
  # ============================================================================
27
- BATCH_SIZE = 4
28
- GRAD_ACCUM = 2
29
  LR = 1e-4
30
  EPOCHS = 10
31
  MAX_SEQ = 128
@@ -36,26 +55,16 @@ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
36
 
37
  # HuggingFace Hub
38
  HF_REPO = "AbstractPhil/tiny-flux"
39
- SAVE_EVERY = 1000 # steps - local save
40
- UPLOAD_EVERY = 1000 # steps - hub upload
41
- SAMPLE_EVERY = 500 # steps - generate samples
42
- LOG_EVERY = 10 # steps - tensorboard
43
-
44
- # Checkpoint loading target
45
- # Options:
46
- # None or "latest" - load most recent checkpoint
47
- # "best" - load best model
48
- # int (e.g. 1500) - load specific step
49
- # "hub:step_1000" - load specific checkpoint from hub
50
- # "local:path/to/checkpoint.safetensors" or "local:path/to/checkpoint.pt"
51
- # "none" - start fresh, ignore existing checkpoints
52
- LOAD_TARGET = "latest"
53
-
54
- # Manual resume step (set to override step from checkpoint, or None to use checkpoint's step)
55
- # Useful when checkpoint doesn't contain step info
56
- RESUME_STEP = None # e.g., 5000 to resume from step 5000
57
-
58
- # Local paths
59
  CHECKPOINT_DIR = "./tiny_flux_checkpoints"
60
  LOG_DIR = "./tiny_flux_logs"
61
  SAMPLE_DIR = "./tiny_flux_samples"
@@ -69,7 +78,6 @@ os.makedirs(SAMPLE_DIR, exist_ok=True)
69
  # ============================================================================
70
  print("Setting up HuggingFace Hub...")
71
  api = HfApi()
72
-
73
  try:
74
  api.create_repo(repo_id=HF_REPO, exist_ok=True, repo_type="model")
75
  print(f"✓ Repo ready: {HF_REPO}")
@@ -87,7 +95,7 @@ print(f"✓ Tensorboard: {LOG_DIR}/{run_name}")
87
  # LOAD DATASET
88
  # ============================================================================
89
  print("\nLoading dataset...")
90
- ds = load_dataset("AbstractPhil/flux-schnell-teacher-latents", split="train")
91
  print(f"Samples: {len(ds)}")
92
 
93
  # ============================================================================
@@ -109,104 +117,151 @@ for p in clip_enc.parameters(): p.requires_grad = False
109
  # ============================================================================
110
  print("Loading Flux VAE for samples...")
111
  from diffusers import AutoencoderKL
 
112
  vae = AutoencoderKL.from_pretrained(
113
- "black-forest-labs/FLUX.1-schnell",
114
  subfolder="vae",
115
  torch_dtype=DTYPE
116
  ).to(DEVICE).eval()
117
  for p in vae.parameters(): p.requires_grad = False
118
 
119
  # ============================================================================
120
- # ENCODING HELPERS
121
- # ============================================================================
122
- @torch.no_grad()
123
- def encode_prompt(prompt):
124
- t5_in = t5_tok(prompt, max_length=MAX_SEQ, padding="max_length", truncation=True, return_tensors="pt").to(DEVICE)
125
- t5_out = t5_enc(input_ids=t5_in.input_ids, attention_mask=t5_in.attention_mask).last_hidden_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
- clip_in = clip_tok(prompt, max_length=77, padding="max_length", truncation=True, return_tensors="pt").to(DEVICE)
128
- clip_out = clip_enc(input_ids=clip_in.input_ids, attention_mask=clip_in.attention_mask)
129
  return t5_out, clip_out.pooler_output
130
 
 
 
 
 
 
 
 
131
  # ============================================================================
132
- # FLOW MATCHING HELPERS
133
- # ============================================================================
134
- # Rectified Flow / Flow Matching formulation:
135
- # x_t = (1-t) * x_0 + t * x_1
136
- # where x_0 = noise, x_1 = data
137
- # t=0: pure noise, t=1: pure data
138
- # velocity v = x_1 - x_0 = data - noise
139
- #
140
- # Training: model learns to predict v given (x_t, t)
141
- # Inference: start from noise (t=0), integrate to data (t=1)
142
- # x_{t+dt} = x_t + v_pred * dt
143
  # ============================================================================
 
 
 
 
144
 
145
- def flux_shift(t, s=SHIFT):
146
- """Flux timestep shift for training distribution.
147
-
148
- Shifts timesteps towards higher values (closer to data),
149
- making training focus more on refining details.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- s=3.0 (default): flux_shift(0.5) ≈ 0.75
152
- """
 
 
 
 
 
 
 
153
  return s * t / (1 + (s - 1) * t)
154
 
155
- def flux_shift_inverse(t_shifted, s=SHIFT):
156
- """Inverse of flux_shift."""
157
- return t_shifted / (s - (s - 1) * t_shifted)
158
 
159
  def min_snr_weight(t, gamma=MIN_SNR):
160
- """Min-SNR weighting to balance loss across timesteps.
161
-
162
- Downweights very easy timesteps (near t=0 or t=1).
163
- gamma=5.0 is typical.
164
- """
165
  snr = (t / (1 - t).clamp(min=1e-5)).pow(2)
166
  return torch.clamp(snr, max=gamma) / snr.clamp(min=1e-5)
167
 
 
168
  # ============================================================================
169
- # SAMPLING FUNCTION
170
  # ============================================================================
171
- @torch.no_grad()
172
  def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=64):
173
- """Generate sample images using Euler sampling.
174
-
175
- Flow matching: x_t = (1-t)*noise + t*data, v = data - noise
176
- At t=0: pure noise. At t=1: pure data.
177
- We integrate from t=0 to t=1.
178
- """
179
  model.eval()
180
  B = len(prompts)
181
- C = 16 # VAE channels
182
 
183
- # Encode prompts
184
- t5_embeds, clip_pooleds = [], []
185
- for p in prompts:
186
- t5_out, clip_pooled = encode_prompt(p)
187
- t5_embeds.append(t5_out.squeeze(0))
188
- clip_pooleds.append(clip_pooled.squeeze(0))
189
- t5_embeds = torch.stack(t5_embeds)
190
- clip_pooleds = torch.stack(clip_pooleds)
191
 
192
- # Start from pure noise (t=0)
193
  x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)
194
 
195
- # Create image IDs
196
  img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
197
 
198
- # Euler sampling: t goes from 0 (noise) to 1 (data)
199
- timesteps = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
 
200
 
 
201
  for i in range(num_steps):
202
  t_curr = timesteps[i]
203
  t_next = timesteps[i + 1]
204
- dt = t_next - t_curr # positive
205
 
206
- t_batch = t_curr.expand(B)
207
-
208
- # Conditional prediction
209
  guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)
 
210
  v_cond = model(
211
  hidden_states=x,
212
  encoder_hidden_states=t5_embeds,
@@ -216,13 +271,10 @@ def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=6
216
  guidance=guidance,
217
  )
218
 
219
- # Euler step: x_{t+dt} = x_t + v * dt
220
  x = x + v_cond * dt
221
 
222
- # Reshape to image format: (B, H*W, C) -> (B, C, H, W)
223
  latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
224
-
225
- # Decode with VAE (match VAE dtype)
226
  latents = latents / vae.config.scaling_factor
227
  images = vae.decode(latents.to(vae.dtype)).sample
228
  images = (images / 2 + 0.5).clamp(0, 1)
@@ -230,94 +282,133 @@ def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=6
230
  model.train()
231
  return images
232
 
 
233
  def save_samples(images, prompts, step, save_dir):
234
- """Save sample images and log to tensorboard."""
235
  from torchvision.utils import make_grid, save_image
236
 
237
- # Save individual images
238
  for i, (img, prompt) in enumerate(zip(images, prompts)):
239
  safe_prompt = prompt[:50].replace(" ", "_").replace("/", "-")
240
  path = os.path.join(save_dir, f"step{step}_{i}_{safe_prompt}.png")
241
  save_image(img, path)
242
 
243
- # Log grid to tensorboard
244
  grid = make_grid(images, nrow=2, normalize=False)
245
  writer.add_image("samples", grid, step)
246
-
247
- # Log prompts
248
  writer.add_text("sample_prompts", "\n".join(prompts), step)
249
-
250
  print(f" ✓ Saved {len(images)} samples")
251
 
 
252
  # ============================================================================
253
- # COLLATE
254
  # ============================================================================
255
- def collate(batch):
256
- latents, t5_embeds, clip_embeds, prompts = [], [], [], []
257
- for b in batch:
258
- latents.append(torch.tensor(np.array(b["latent"]), dtype=DTYPE))
259
- t5_out, clip_pooled = encode_prompt(b["prompt"])
260
- t5_embeds.append(t5_out.squeeze(0))
261
- clip_embeds.append(clip_pooled.squeeze(0))
262
- prompts.append(b["prompt"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  return {
264
- "latents": torch.stack(latents).to(DEVICE),
265
- "t5_embeds": torch.stack(t5_embeds),
266
- "clip_pooled": torch.stack(clip_embeds),
267
- "prompts": prompts,
268
  }
269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  # ============================================================================
271
  # CHECKPOINT FUNCTIONS
272
  # ============================================================================
273
  def load_weights(path):
274
- """Load weights from .safetensors or .pt file."""
275
  if path.endswith(".safetensors"):
276
- return load_file(path)
277
  elif path.endswith(".pt"):
278
  ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
279
  if isinstance(ckpt, dict):
280
- if "model" in ckpt:
281
- return ckpt["model"]
282
- elif "state_dict" in ckpt:
283
- return ckpt["state_dict"]
284
- else:
285
- # Check if it looks like a state dict (has tensor values)
286
- first_val = next(iter(ckpt.values()), None)
287
- if isinstance(first_val, torch.Tensor):
288
- return ckpt
289
- # Otherwise might have optimizer etc, look for model keys
290
- return ckpt
291
- return ckpt
292
  else:
293
- # Try safetensors first, then pt
294
  try:
295
- return load_file(path)
296
  except:
297
- return torch.load(path, map_location=DEVICE, weights_only=False)
 
 
 
 
 
 
 
 
298
 
299
  def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path):
300
- """Save checkpoint locally."""
301
  os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
302
 
 
 
 
 
303
  weights_path = path.replace(".pt", ".safetensors")
304
- save_file(model.state_dict(), weights_path)
305
 
306
- state = {
307
  "step": step,
308
  "epoch": epoch,
309
  "loss": loss,
310
  "optimizer": optimizer.state_dict(),
311
  "scheduler": scheduler.state_dict(),
312
- }
313
- torch.save(state, path)
314
  print(f" ✓ Saved checkpoint: step {step}")
315
  return weights_path
316
 
317
- def upload_checkpoint(weights_path, step, config, include_logs=True):
318
- """Upload checkpoint to HuggingFace Hub."""
 
319
  try:
320
- # Upload weights
321
  api.upload_file(
322
  path_or_fileobj=weights_path,
323
  path_in_repo=f"checkpoints/step_{step}.safetensors",
@@ -325,286 +416,86 @@ def upload_checkpoint(weights_path, step, config, include_logs=True):
325
  commit_message=f"Checkpoint step {step}",
326
  )
327
 
328
- # Upload config
329
  config_path = os.path.join(CHECKPOINT_DIR, "config.json")
330
  with open(config_path, "w") as f:
331
  json.dump(config.__dict__, f, indent=2)
332
- api.upload_file(
333
- path_or_fileobj=config_path,
334
- path_in_repo="config.json",
335
- repo_id=HF_REPO,
336
- )
337
-
338
- # Upload tensorboard logs
339
- if include_logs and os.path.exists(LOG_DIR):
340
- api.upload_folder(
341
- folder_path=LOG_DIR,
342
- path_in_repo="logs",
343
- repo_id=HF_REPO,
344
- commit_message=f"Logs at step {step}",
345
- )
346
-
347
- # Upload samples
348
- if os.path.exists(SAMPLE_DIR) and os.listdir(SAMPLE_DIR):
349
- api.upload_folder(
350
- folder_path=SAMPLE_DIR,
351
- path_in_repo="samples",
352
- repo_id=HF_REPO,
353
- commit_message=f"Samples at step {step}",
354
- )
355
 
356
- print(f" ✓ Uploaded to {HF_REPO}")
357
  except Exception as e:
358
  print(f" ⚠ Upload failed: {e}")
359
 
 
360
  def load_checkpoint(model, optimizer, scheduler, target):
361
- """
362
- Load checkpoint based on target specification.
363
 
364
- Args:
365
- target:
366
- None, "latest" - most recent checkpoint
367
- "best" - best model
368
- int (1500) - specific step
369
- "hub:step_1000" - specific hub checkpoint
370
- "local:/path/to/file.safetensors" or "local:/path/to/file.pt" - specific local file
371
- "none" - skip loading, start fresh
372
- """
373
- if target == "none":
374
- print("Starting fresh (no checkpoint loading)")
375
  return 0, 0
376
 
377
- start_step, start_epoch = 0, 0
378
-
379
- # Parse target
380
- if target is None or target == "latest":
381
- load_mode = "latest"
382
- load_path = None
383
- elif target == "best":
384
- load_mode = "best"
385
- load_path = None
386
- elif isinstance(target, int):
387
- load_mode = "step"
388
- load_path = target
389
- elif target.startswith("hub:"):
390
- load_mode = "hub"
391
- load_path = target[4:] # Remove "hub:" prefix
392
- elif target.startswith("local:"):
393
- load_mode = "local"
394
- load_path = target[6:] # Remove "local:" prefix
395
- else:
396
- print(f"Unknown target format: {target}, trying as step number")
397
  try:
398
- load_mode = "step"
399
- load_path = int(target)
400
- except:
401
- load_mode = "latest"
402
- load_path = None
403
-
404
- # Load based on mode
405
- if load_mode == "local":
406
- # Direct local file (.pt or .safetensors)
407
- if os.path.exists(load_path):
408
- weights = load_weights(load_path)
409
- model.load_state_dict(weights)
410
-
411
- # Try to find associated state file for optimizer/scheduler
412
- if load_path.endswith(".safetensors"):
413
- state_path = load_path.replace(".safetensors", ".pt")
414
- elif load_path.endswith(".pt"):
415
- # The .pt file might contain everything
416
- ckpt = torch.load(load_path, map_location=DEVICE, weights_only=False)
417
- if isinstance(ckpt, dict):
418
- # Debug: show what keys are in the checkpoint
419
- non_tensor_keys = [k for k in ckpt.keys() if not isinstance(ckpt.get(k), torch.Tensor)]
420
- if non_tensor_keys:
421
- print(f" Checkpoint keys: {non_tensor_keys}")
422
-
423
- # Extract step/epoch - try multiple common key names
424
- start_step = ckpt.get("step", ckpt.get("global_step", ckpt.get("iteration", 0)))
425
- start_epoch = ckpt.get("epoch", 0)
426
-
427
- # Also check for nested state dict
428
- if "state" in ckpt and isinstance(ckpt["state"], dict):
429
- start_step = ckpt["state"].get("step", start_step)
430
- start_epoch = ckpt["state"].get("epoch", start_epoch)
431
-
432
- # Try to load optimizer/scheduler if present
433
- if "optimizer" in ckpt:
434
- try:
435
- optimizer.load_state_dict(ckpt["optimizer"])
436
- if "scheduler" in ckpt:
437
- scheduler.load_state_dict(ckpt["scheduler"])
438
- except Exception as e:
439
- print(f" Note: Could not load optimizer state: {e}")
440
- state_path = None
441
  else:
442
- state_path = load_path + ".pt"
443
-
444
- if state_path and os.path.exists(state_path):
445
- state = torch.load(state_path, map_location=DEVICE, weights_only=False)
446
  try:
447
- start_step = state.get("step", start_step)
448
- start_epoch = state.get("epoch", start_epoch)
449
- if "optimizer" in state:
450
- optimizer.load_state_dict(state["optimizer"])
451
- if "scheduler" in state:
452
- scheduler.load_state_dict(state["scheduler"])
453
- except Exception as e:
454
- print(f" Note: Could not load optimizer state: {e}")
455
 
456
- print(f"✓ Loaded local: {load_path} (step {start_step})")
457
- return start_step, start_epoch
458
- else:
459
- print(f"⚠ Local file not found: {load_path}")
460
-
461
- elif load_mode == "hub":
462
- # Specific hub checkpoint - try both extensions
463
- for ext in [".safetensors", ".pt", ""]:
464
- try:
465
- if load_path.endswith((".safetensors", ".pt")):
466
- filename = load_path if "/" in load_path else f"checkpoints/{load_path}"
467
- else:
468
- filename = f"checkpoints/{load_path}{ext}"
469
- local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
470
- weights = load_weights(local_path)
471
- model.load_state_dict(weights)
472
- # Extract step from filename
473
- if "step_" in load_path:
474
- start_step = int(load_path.split("step_")[-1].replace(".safetensors", "").replace(".pt", ""))
475
- print(f"✓ Loaded from Hub: {filename} (step {start_step})")
476
- return start_step, start_epoch
477
- except Exception as e:
478
- continue
479
- print(f"⚠ Could not load from hub: {load_path}")
480
-
481
- elif load_mode == "best":
482
- # Try hub best first (try both extensions)
483
- for ext in [".safetensors", ".pt"]:
484
- try:
485
- filename = f"model{ext}" if ext else "model.safetensors"
486
- local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
487
- weights = load_weights(local_path)
488
- model.load_state_dict(weights)
489
- print(f"✓ Loaded best model from Hub")
490
- return start_step, start_epoch
491
- except:
492
- continue
493
-
494
- # Try local best (both extensions)
495
- for ext in [".safetensors", ".pt"]:
496
- best_path = os.path.join(CHECKPOINT_DIR, f"best{ext}")
497
- if os.path.exists(best_path):
498
- weights = load_weights(best_path)
499
- model.load_state_dict(weights)
500
- # Try to load optimizer state
501
- state_path = best_path.replace(ext, ".pt") if ext == ".safetensors" else best_path
502
- if os.path.exists(state_path):
503
- state = torch.load(state_path, map_location=DEVICE, weights_only=False)
504
- if isinstance(state, dict) and "step" in state:
505
- start_step = state.get("step", 0)
506
- start_epoch = state.get("epoch", 0)
507
- print(f"✓ Loaded local best (step {start_step})")
508
- return start_step, start_epoch
509
-
510
- elif load_mode == "step":
511
- # Specific step number
512
- step_num = load_path
513
- # Try hub (both extensions)
514
- for ext in [".safetensors", ".pt"]:
515
- try:
516
- filename = f"checkpoints/step_{step_num}{ext}"
517
- local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
518
- weights = load_weights(local_path)
519
- model.load_state_dict(weights)
520
- start_step = step_num
521
- print(f"✓ Loaded step {step_num} from Hub")
522
- return start_step, start_epoch
523
- except:
524
- continue
525
-
526
- # Try local (both extensions)
527
- for ext in [".safetensors", ".pt"]:
528
- local_path = os.path.join(CHECKPOINT_DIR, f"step_{step_num}{ext}")
529
- if os.path.exists(local_path):
530
- weights = load_weights(local_path)
531
- model.load_state_dict(weights)
532
- state_path = local_path.replace(".safetensors", ".pt") if ext == ".safetensors" else local_path
533
- if os.path.exists(state_path):
534
- state = torch.load(state_path, map_location=DEVICE, weights_only=False)
535
- if isinstance(state, dict):
536
- try:
537
- if "optimizer" in state:
538
- optimizer.load_state_dict(state["optimizer"])
539
- if "scheduler" in state:
540
- scheduler.load_state_dict(state["scheduler"])
541
- start_epoch = state.get("epoch", 0)
542
- except:
543
- pass
544
- start_step = step_num
545
- print(f"✓ Loaded local step {step_num}")
546
- return start_step, start_epoch
547
- print(f"⚠ Step {step_num} not found")
548
-
549
- # Default: latest
550
- # Try Hub first (both extensions)
551
- try:
552
- files = api.list_repo_files(repo_id=HF_REPO)
553
- checkpoints = [f for f in files if f.startswith("checkpoints/step_") and (f.endswith(".safetensors") or f.endswith(".pt"))]
554
- if checkpoints:
555
- # Sort by step number
556
- def get_step(f):
557
- return int(f.split("step_")[-1].replace(".safetensors", "").replace(".pt", ""))
558
- checkpoints.sort(key=get_step)
559
- latest = checkpoints[-1]
560
- step = get_step(latest)
561
- local_path = hf_hub_download(repo_id=HF_REPO, filename=latest)
562
- weights = load_weights(local_path)
563
- model.load_state_dict(weights)
564
- start_step = step
565
- print(f"✓ Loaded latest from Hub: step {step}")
566
- return start_step, start_epoch
567
- except Exception as e:
568
- print(f"Hub check: {e}")
569
-
570
- # Try local (both extensions)
571
- if os.path.exists(CHECKPOINT_DIR):
572
- local_ckpts = [f for f in os.listdir(CHECKPOINT_DIR) if f.startswith("step_") and (f.endswith(".safetensors") or f.endswith(".pt"))]
573
- # Filter to just weights files (not state .pt files that pair with .safetensors)
574
- local_ckpts = [f for f in local_ckpts if not (f.endswith(".pt") and f.replace(".pt", ".safetensors") in local_ckpts)]
575
- if local_ckpts:
576
- def get_step(f):
577
- return int(f.split("step_")[-1].replace(".safetensors", "").replace(".pt", ""))
578
- local_ckpts.sort(key=get_step)
579
- latest = local_ckpts[-1]
580
- step = get_step(latest)
581
- weights_path = os.path.join(CHECKPOINT_DIR, latest)
582
  weights = load_weights(weights_path)
583
- model.load_state_dict(weights)
584
- # Try to load optimizer state
585
- state_path = weights_path.replace(".safetensors", ".pt") if weights_path.endswith(".safetensors") else weights_path
586
- if os.path.exists(state_path):
587
- state = torch.load(state_path, map_location=DEVICE, weights_only=False)
588
- if isinstance(state, dict):
589
- try:
590
- if "optimizer" in state:
591
- optimizer.load_state_dict(state["optimizer"])
592
- if "scheduler" in state:
593
- scheduler.load_state_dict(state["scheduler"])
594
- start_epoch = state.get("epoch", 0)
595
- except:
596
- pass
597
- start_step = step
598
- print(f"✓ Loaded latest local: step {step}")
599
  return start_step, start_epoch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
 
601
  print("No checkpoint found, starting fresh")
602
  return 0, 0
603
 
 
604
  # ============================================================================
605
- # DATALOADER
606
  # ============================================================================
607
- loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate, num_workers=0)
 
 
 
 
 
 
 
 
 
608
 
609
  # ============================================================================
610
  # MODEL
@@ -612,33 +503,46 @@ loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate,
612
  config = TinyFluxConfig()
613
  model = TinyFlux(config).to(DEVICE).to(DTYPE)
614
  print(f"\nParams: {sum(p.numel() for p in model.parameters()):,}")
615
- model = torch.compile(model, mode="default")
616
 
617
  # ============================================================================
618
- # OPTIMIZER & SCHEDULER
619
  # ============================================================================
620
- opt = torch.optim.AdamW(model.parameters(), lr=LR, betas=(0.9, 0.99), weight_decay=0.01)
 
 
 
 
 
 
 
621
  total_steps = len(loader) * EPOCHS // GRAD_ACCUM
622
  warmup = min(500, total_steps // 10)
623
 
 
624
  def lr_fn(step):
625
- if step < warmup: return step / warmup
 
626
  return 0.5 * (1 + math.cos(math.pi * (step - warmup) / (total_steps - warmup)))
627
 
 
628
  sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)
629
 
630
  # ============================================================================
631
- # LOAD CHECKPOINT
632
  # ============================================================================
633
  print(f"\nLoad target: {LOAD_TARGET}")
634
  start_step, start_epoch = load_checkpoint(model, opt, sched, LOAD_TARGET)
635
 
636
- # Override start_step if RESUME_STEP is set
637
  if RESUME_STEP is not None:
638
  print(f"Overriding start_step: {start_step} -> {RESUME_STEP}")
639
  start_step = RESUME_STEP
640
 
641
- # Log config to tensorboard
 
 
 
 
 
642
  writer.add_text("config", json.dumps(config.__dict__, indent=2), 0)
643
  writer.add_text("training_config", json.dumps({
644
  "batch_size": BATCH_SIZE,
@@ -647,11 +551,10 @@ writer.add_text("training_config", json.dumps({
647
  "epochs": EPOCHS,
648
  "min_snr": MIN_SNR,
649
  "shift": SHIFT,
 
650
  }, indent=2), 0)
651
 
652
- # ============================================================================
653
- # SAMPLE PROMPTS FOR PERIODIC GENERATION
654
- # ============================================================================
655
  SAMPLE_PROMPTS = [
656
  "a photo of a cat sitting on a windowsill",
657
  "a beautiful sunset over mountains",
@@ -660,67 +563,55 @@ SAMPLE_PROMPTS = [
660
  ]
661
 
662
  # ============================================================================
663
- # TRAINING
664
  # ============================================================================
665
  print(f"\nTraining {EPOCHS} epochs, {total_steps} total steps")
666
  print(f"Resuming from step {start_step}, epoch {start_epoch}")
667
  print(f"Save: {SAVE_EVERY}, Upload: {UPLOAD_EVERY}, Sample: {SAMPLE_EVERY}, Log: {LOG_EVERY}")
 
668
 
669
  model.train()
670
  step = start_step
671
  best = float("inf")
672
 
 
 
 
673
  for ep in range(start_epoch, EPOCHS):
674
  ep_loss = 0
675
  ep_batches = 0
676
- pbar = tqdm(loader, desc=f"E{ep+1}")
677
 
678
  for i, batch in enumerate(pbar):
679
- latents = batch["latents"] # Ground truth data (VAE encoded images)
680
- t5 = batch["t5_embeds"]
681
- clip = batch["clip_pooled"]
 
682
 
683
  B, C, H, W = latents.shape
684
 
685
- # ================================================================
686
- # FLOW MATCHING FORMULATION
687
- # ================================================================
688
- # x_1 = data (what we want to generate)
689
- # x_0 = noise (where we start at inference)
690
- # x_t = (1-t)*x_0 + t*x_1 (linear interpolation)
691
- #
692
- # At t=0: x_t = x_0 (pure noise)
693
- # At t=1: x_t = x_1 (pure data)
694
- #
695
- # Velocity field: v = dx/dt = x_1 - x_0
696
- # Model learns to predict v given (x_t, t)
697
- #
698
- # At inference: start from noise, integrate v from t=0 to t=1
699
- # ================================================================
700
-
701
- # Reshape data to sequence format: (B, C, H, W) -> (B, H*W, C)
702
- data = latents.permute(0, 2, 3, 1).reshape(B, H*W, C) # x_1
703
- noise = torch.randn_like(data) # x_0
704
 
705
- # Sample timesteps with logit-normal distribution + Flux shift
706
- # This biases training towards higher t (closer to data)
707
  t = torch.sigmoid(torch.randn(B, device=DEVICE))
708
- t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1-1e-4)
709
 
710
- # Create noisy samples via linear interpolation
711
  t_expanded = t.view(B, 1, 1)
712
- x_t = (1 - t_expanded) * noise + t_expanded * data # Noisy sample at time t
713
 
714
- # Target velocity: direction from noise to data
715
  v_target = data - noise
716
 
717
- # Create position IDs for RoPE
718
  img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
719
 
720
- # Random guidance scale (for CFG training)
721
- guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1 # [1, 5]
722
 
723
- # Forward pass: predict velocity
724
  with torch.autocast("cuda", dtype=DTYPE):
725
  v_pred = model(
726
  hidden_states=x_t,
@@ -731,10 +622,8 @@ for ep in range(start_epoch, EPOCHS):
731
  guidance=guidance,
732
  )
733
 
734
- # Loss: MSE between predicted and target velocity
735
  loss_raw = F.mse_loss(v_pred, v_target, reduction="none").mean(dim=[1, 2])
736
-
737
- # Min-SNR weighting: downweight easy timesteps (near t=0 or t=1)
738
  snr_weights = min_snr_weight(t)
739
  loss = (loss_raw * snr_weights).mean() / GRAD_ACCUM
740
  loss.backward()
@@ -743,38 +632,33 @@ for ep in range(start_epoch, EPOCHS):
743
  grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
744
  opt.step()
745
  sched.step()
746
- opt.zero_grad()
747
  step += 1
748
 
749
- # Tensorboard logging
750
  if step % LOG_EVERY == 0:
751
  writer.add_scalar("train/loss", loss.item() * GRAD_ACCUM, step)
752
  writer.add_scalar("train/lr", sched.get_last_lr()[0], step)
753
  writer.add_scalar("train/grad_norm", grad_norm.item(), step)
754
  writer.add_scalar("train/t_mean", t.mean().item(), step)
755
- writer.add_scalar("train/snr_weight_mean", snr_weights.mean().item(), step)
756
 
757
- # Generate samples
758
  if step % SAMPLE_EVERY == 0:
759
  print(f"\n Generating samples at step {step}...")
760
  images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
761
  save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)
762
 
763
- # Save checkpoint
764
  if step % SAVE_EVERY == 0:
765
  ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt")
766
  weights_path = save_checkpoint(model, opt, sched, step, ep, loss.item(), ckpt_path)
767
 
768
- # Upload
769
  if step % UPLOAD_EVERY == 0:
770
- upload_checkpoint(weights_path, step, config, include_logs=True)
771
 
772
  ep_loss += loss.item() * GRAD_ACCUM
773
  ep_batches += 1
774
- pbar.set_postfix(loss=f"{loss.item()*GRAD_ACCUM:.4f}", lr=f"{sched.get_last_lr()[0]:.1e}", step=step)
775
 
776
  avg = ep_loss / max(ep_batches, 1)
777
- print(f"Epoch {ep+1} loss: {avg:.4f}")
778
  writer.add_scalar("train/epoch_loss", avg, ep + 1)
779
 
780
  if avg < best:
@@ -787,7 +671,7 @@ for ep in range(start_epoch, EPOCHS):
787
  path_or_fileobj=weights_path,
788
  path_in_repo="model.safetensors",
789
  repo_id=HF_REPO,
790
- commit_message=f"Best model (epoch {ep+1}, loss {avg:.4f})",
791
  )
792
  print(f" ✓ Uploaded best to {HF_REPO}")
793
  except Exception as e:
@@ -800,20 +684,16 @@ print("\nSaving final model...")
800
  final_path = os.path.join(CHECKPOINT_DIR, "final.pt")
801
  weights_path = save_checkpoint(model, opt, sched, step, EPOCHS, best, final_path)
802
 
803
- # Final samples
804
  print("Generating final samples...")
805
  images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
806
  save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)
807
 
808
- # Final upload
809
  try:
810
  api.upload_file(path_or_fileobj=weights_path, path_in_repo="model.safetensors", repo_id=HF_REPO)
811
  config_path = os.path.join(CHECKPOINT_DIR, "config.json")
812
  with open(config_path, "w") as f:
813
  json.dump(config.__dict__, f, indent=2)
814
  api.upload_file(path_or_fileobj=config_path, path_in_repo="config.json", repo_id=HF_REPO)
815
- api.upload_folder(folder_path=LOG_DIR, path_in_repo="logs", repo_id=HF_REPO)
816
- api.upload_folder(folder_path=SAMPLE_DIR, path_in_repo="samples", repo_id=HF_REPO)
817
  print(f"\n✓ Training complete! https://huggingface.co/{HF_REPO}")
818
  except Exception as e:
819
  print(f"\n⚠ Final upload failed: {e}")
 
1
  # ============================================================================
2
+ # TinyFlux Training Cell - OPTIMIZED
3
  # ============================================================================
4
+ # Optimizations:
5
+ # - TF32 and cuDNN settings for faster matmuls
6
+ # - Fused AdamW optimizer
7
+ # - Pre-encoded prompts (encode once at startup, not per batch)
8
+ # - Batched prompt encoding
9
+ # - DataLoader with num_workers and pin_memory
10
+ # - torch.inference_mode() for sampling
11
+ # - Cached img_ids in model
12
+ # - torch.compile with max-autotune
13
  # ============================================================================
14
 
15
  import torch
 
27
  import json
28
  from datetime import datetime
29
 
30
+ # ============================================================================
31
+ # CUDA OPTIMIZATIONS - Set these BEFORE model creation
32
+ # ============================================================================
33
+ # New PyTorch 2.x API for TF32
34
+ torch.backends.cuda.matmul.allow_tf32 = True
35
+ torch.backends.cudnn.allow_tf32 = True
36
+ torch.backends.cudnn.benchmark = True
37
+ torch.set_float32_matmul_precision('high')
38
+
39
+ # Suppress the deprecation warning (settings still work)
40
+ import warnings
41
+ warnings.filterwarnings('ignore', message='.*TF32.*')
42
+
43
  # ============================================================================
44
  # CONFIG
45
  # ============================================================================
46
+ BATCH_SIZE = 128
47
+ GRAD_ACCUM = 1
48
  LR = 1e-4
49
  EPOCHS = 10
50
  MAX_SEQ = 128
 
55
 
56
  # HuggingFace Hub
57
  HF_REPO = "AbstractPhil/tiny-flux"
58
+ SAVE_EVERY = 1000
59
+ UPLOAD_EVERY = 1000
60
+ SAMPLE_EVERY = 500
61
+ LOG_EVERY = 10
62
+
63
+ # Checkpoint loading
64
+ LOAD_TARGET = "hub:step_24000" # "latest", "best", int, "hub:step_X", "local:path", "none"
65
+ RESUME_STEP = None
66
+
67
+ # Paths
 
 
 
 
 
 
 
 
 
 
68
  CHECKPOINT_DIR = "./tiny_flux_checkpoints"
69
  LOG_DIR = "./tiny_flux_logs"
70
  SAMPLE_DIR = "./tiny_flux_samples"
 
78
  # ============================================================================
79
  print("Setting up HuggingFace Hub...")
80
  api = HfApi()
 
81
  try:
82
  api.create_repo(repo_id=HF_REPO, exist_ok=True, repo_type="model")
83
  print(f"✓ Repo ready: {HF_REPO}")
 
95
  # LOAD DATASET
96
  # ============================================================================
97
  print("\nLoading dataset...")
98
+ ds = load_dataset("AbstractPhil/flux-schnell-teacher-latents", "train_3_512", split="train")
99
  print(f"Samples: {len(ds)}")
100
 
101
  # ============================================================================
 
117
  # ============================================================================
118
  print("Loading Flux VAE for samples...")
119
  from diffusers import AutoencoderKL
120
+
121
  vae = AutoencoderKL.from_pretrained(
122
+ "black-forest-labs/FLUX.1-schnell",
123
  subfolder="vae",
124
  torch_dtype=DTYPE
125
  ).to(DEVICE).eval()
126
  for p in vae.parameters(): p.requires_grad = False
127
 
128
  # ============================================================================
129
+ # BATCHED ENCODING - Much faster than one-by-one
130
+ # ============================================================================
131
@torch.inference_mode()
def encode_prompts_batched(prompts: list) -> tuple:
    """Run a batch of prompts through both text encoders in one pass each.

    Args:
        prompts: Prompt strings; the whole list is tokenized in a single
            call per tokenizer.

    Returns:
        ``(t5_hidden, clip_pooled)`` - the T5 encoder's ``last_hidden_state``
        and the CLIP encoder's ``pooler_output`` for the batch. Both are
        produced under ``torch.inference_mode()``.
    """
    # Both tokenizers pad/truncate to a fixed length; only that length differs.
    shared_tok_args = dict(padding="max_length", truncation=True, return_tensors="pt")

    # T5: keep the per-token hidden states.
    t5_tokens = t5_tok(prompts, max_length=MAX_SEQ, **shared_tok_args).to(DEVICE)
    t5_hidden = t5_enc(
        input_ids=t5_tokens.input_ids,
        attention_mask=t5_tokens.attention_mask,
    ).last_hidden_state

    # CLIP: keep only the pooled vector.
    clip_tokens = clip_tok(prompts, max_length=77, **shared_tok_args).to(DEVICE)
    clip_pooled = clip_enc(
        input_ids=clip_tokens.input_ids,
        attention_mask=clip_tokens.attention_mask,
    ).pooler_output

    return t5_hidden, clip_pooled
161
 
162
+
163
@torch.inference_mode()
def encode_prompt(prompt: str) -> tuple:
    """Encode one prompt by delegating to the batched encoder.

    Returns the same ``(t5_hidden, clip_pooled)`` pair as
    ``encode_prompts_batched``, for a batch holding only ``prompt``.
    """
    single_batch = [prompt]
    return encode_prompts_batched(single_batch)
167
+
168
+
169
  # ============================================================================
170
+ # PRE-ENCODE ALL PROMPTS (with disk caching)
 
 
 
 
 
 
 
 
 
 
171
  # ============================================================================
172
print("\nPre-encoding prompts...")
PRECOMPUTE_ENCODINGS = True
ENCODING_CACHE_DIR = "./encoding_cache"
os.makedirs(ENCODING_CACHE_DIR, exist_ok=True)

# The cache file name is keyed on the dataset length plus an encoder tag.
cache_file = os.path.join(ENCODING_CACHE_DIR, f"encodings_{len(ds)}_t5base_clipL.pt")

if PRECOMPUTE_ENCODINGS:
    if os.path.exists(cache_file):
        # Reuse the encodings written by an earlier run.
        print(f"Loading cached encodings from {cache_file}...")
        cached = torch.load(cache_file, weights_only=True)
        all_t5_embeds = cached["t5_embeds"]
        all_clip_pooled = cached["clip_pooled"]
        print("✓ Loaded cached encodings")
    else:
        print("Encoding prompts (will cache for future runs)...")
        # Read the whole prompt column in one access instead of row by row.
        all_prompts = ds["prompt"]

        encode_batch_size = 64
        all_t5_embeds = []
        all_clip_pooled = []

        for i in tqdm(range(0, len(all_prompts), encode_batch_size), desc="Encoding"):
            batch_prompts = all_prompts[i:i + encode_batch_size]
            t5_out, clip_out = encode_prompts_batched(batch_prompts)
            # Collect results on the CPU; they are concatenated below.
            all_t5_embeds.append(t5_out.cpu())
            all_clip_pooled.append(clip_out.cpu())

        all_t5_embeds = torch.cat(all_t5_embeds, dim=0)
        all_clip_pooled = torch.cat(all_clip_pooled, dim=0)

        # Save cache (original note: ~750MB for 10k samples).
        # Write to a temporary name and rename afterwards: an interrupted
        # save must not leave a truncated file at `cache_file`, because the
        # branch above would try to load it on every later run.
        tmp_cache_file = cache_file + ".tmp"
        torch.save({
            "t5_embeds": all_t5_embeds,
            "clip_pooled": all_clip_pooled,
        }, tmp_cache_file)
        os.replace(tmp_cache_file, cache_file)
        print(f"✓ Saved encoding cache to {cache_file}")

    print(f" T5 embeds: {all_t5_embeds.shape}")
    print(f" CLIP pooled: {all_clip_pooled.shape}")
215
+
216
+
217
+ # ============================================================================
218
+ # FLOW MATCHING HELPERS
219
+ # ============================================================================
220
def flux_shift(t, s=SHIFT):
    """Apply the Flux timestep shift ``s*t / (1 + (s - 1)*t)``.

    Used both when sampling training timesteps and when building the
    sampling schedule.
    """
    scaled = s * t
    denom = 1 + (s - 1) * t
    return scaled / denom
223
 
 
 
 
224
 
225
def min_snr_weight(t, gamma=MIN_SNR):
    """Per-timestep loss weight ``min(snr, gamma) / snr``.

    ``snr`` is computed as ``(t / (1 - t)) ** 2``; both the ``1 - t``
    denominator and the final divisor are clamped to at least ``1e-5`` to
    avoid dividing by zero.
    """
    one_minus_t = (1 - t).clamp(min=1e-5)
    snr = (t / one_minus_t) ** 2
    capped = torch.clamp(snr, max=gamma)
    return capped / snr.clamp(min=1e-5)
229
 
230
+
231
  # ============================================================================
232
+ # SAMPLING FUNCTION - Optimized
233
  # ============================================================================
234
+ @torch.inference_mode()
235
  def generate_samples(model, prompts, num_steps=20, guidance_scale=3.5, H=64, W=64):
236
+ """Generate sample images using Euler sampling."""
 
 
 
 
 
237
  model.eval()
238
  B = len(prompts)
239
+ C = 16
240
 
241
+ # Batch encode prompts
242
+ t5_embeds, clip_pooleds = encode_prompts_batched(prompts)
243
+ t5_embeds = t5_embeds.to(DTYPE)
244
+ clip_pooleds = clip_pooleds.to(DTYPE)
 
 
 
 
245
 
246
+ # Start from pure noise
247
  x = torch.randn(B, H * W, C, device=DEVICE, dtype=DTYPE)
248
 
249
+ # Create image IDs (cached in optimized model)
250
  img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
251
 
252
+ # Timesteps with flux_shift
253
+ t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
254
+ timesteps = flux_shift(t_linear, s=SHIFT)
255
 
256
+ # Euler sampling
257
  for i in range(num_steps):
258
  t_curr = timesteps[i]
259
  t_next = timesteps[i + 1]
260
+ dt = t_next - t_curr
261
 
262
+ t_batch = t_curr.expand(B).to(DTYPE)
 
 
263
  guidance = torch.full((B,), guidance_scale, device=DEVICE, dtype=DTYPE)
264
+
265
  v_cond = model(
266
  hidden_states=x,
267
  encoder_hidden_states=t5_embeds,
 
271
  guidance=guidance,
272
  )
273
 
 
274
  x = x + v_cond * dt
275
 
276
+ # Decode
277
  latents = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
 
 
278
  latents = latents / vae.config.scaling_factor
279
  images = vae.decode(latents.to(vae.dtype)).sample
280
  images = (images / 2 + 0.5).clamp(0, 1)
 
282
  model.train()
283
  return images
284
 
285
+
286
def save_samples(images, prompts, step, save_dir):
    """Write each sample to ``save_dir`` and log the batch to TensorBoard.

    Args:
        images: Batch of images, paired one-to-one with ``prompts``.
        prompts: Prompt strings; the first 50 characters (spaces -> ``_``,
            ``/`` -> ``-``) become part of each file name.
        step: Training step, used in file names and as the logging step.
        save_dir: Directory that receives the PNG files.
    """
    from torchvision.utils import make_grid, save_image

    for index, pair in enumerate(zip(images, prompts)):
        image, prompt = pair
        stem = prompt[:50].replace(" ", "_").replace("/", "-")
        save_image(image, os.path.join(save_dir, f"step{step}_{index}_{stem}.png"))

    # One grid image plus the prompt list go to the module-level writer.
    writer.add_image("samples", make_grid(images, nrow=2, normalize=False), step)
    writer.add_text("sample_prompts", "\n".join(prompts), step)
    print(f" ✓ Saved {len(images)} samples")
299
 
300
+
301
  # ============================================================================
302
+ # OPTIMIZED COLLATE - Returns CPU tensors (GPU transfer in training loop)
303
  # ============================================================================
304
def collate_preencoded(batch):
    """Build a batch from latents plus the pre-computed text encodings.

    Each sample must carry the ``"__index__"`` field added by
    ``IndexedDataset``; it selects the matching rows of ``all_t5_embeds``
    and ``all_clip_pooled``. Nothing is moved to the GPU here - the
    training loop does the transfer.
    """
    row_ids = []
    latent_list = []
    for sample in batch:
        row_ids.append(sample["__index__"])
        latent_list.append(torch.tensor(np.array(sample["latent"]), dtype=DTYPE))

    return {
        "latents": torch.stack(latent_list),
        "t5_embeds": all_t5_embeds[row_ids].to(DTYPE),
        "clip_pooled": all_clip_pooled[row_ids].to(DTYPE),
    }
318
+
319
+
320
def collate_online(batch):
    """Build a batch, encoding the prompts on the fly.

    The prompts go through ``encode_prompts_batched`` inside the collate
    call, and the results are brought back to the CPU. The script pairs
    this function with ``num_workers=0``.
    """
    prompts = []
    latent_list = []
    for sample in batch:
        prompts.append(sample["prompt"])
        latent_list.append(torch.tensor(np.array(sample["latent"]), dtype=DTYPE))
    latents = torch.stack(latent_list)

    t5_hidden, clip_vec = encode_prompts_batched(prompts)

    return {
        "latents": latents,
        "t5_embeds": t5_hidden.cpu().to(DTYPE),
        "clip_pooled": clip_vec.cpu().to(DTYPE),
    }
336
 
337
+
338
+ # Simple wrapper to add index without touching the data
339
class IndexedDataset:
    """Dataset view whose rows also report their own position.

    Every row returned by ``__getitem__`` is a fresh ``dict`` copy of the
    wrapped dataset's row with an extra ``"__index__"`` key holding the
    index that was requested. The wrapped dataset itself is not modified.
    """

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        row = dict(self.ds[idx])
        row.update(__index__=idx)
        return row
349
+
350
+ # Choose collate strategy
351
if not PRECOMPUTE_ENCODINGS:
    # Online encoding runs the text encoders inside collate, so no workers.
    collate_fn, num_workers = collate_online, 0
else:
    # Rows need their position attached so collate can index the cached
    # encodings; the wrapper adds it lazily, without iterating the dataset.
    ds = IndexedDataset(ds)
    collate_fn, num_workers = collate_preencoded, 2
358
+
359
+
360
  # ============================================================================
361
  # CHECKPOINT FUNCTIONS
362
  # ============================================================================
363
def load_weights(path):
    """Load a model state dict from ``path`` and normalise its key names.

    ``.safetensors`` files are read with ``load_file``; ``.pt`` files with
    ``torch.load``. Any other extension is tried as safetensors first and
    as a torch pickle second. A torch checkpoint that nests the weights
    under ``"model"`` or ``"state_dict"`` is unwrapped. Keys saved from a
    ``torch.compile``-wrapped module have their leading ``_orig_mod.``
    removed.

    Args:
        path: Filesystem path of the weights/checkpoint file.

    Returns:
        The state dict (or whatever object the file held, if it was not a
        dict).
    """
    compile_prefix = "_orig_mod."

    def unwrap(obj):
        # Full checkpoints may nest the weights under "model"/"state_dict".
        if isinstance(obj, dict):
            return obj.get("model", obj.get("state_dict", obj))
        return obj

    def torch_load(file_path):
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # use this on checkpoints from a trusted source.
        return unwrap(torch.load(file_path, map_location=DEVICE, weights_only=False))

    if path.endswith(".safetensors"):
        state_dict = load_file(path)
    elif path.endswith(".pt"):
        state_dict = torch_load(path)
    else:
        try:
            state_dict = load_file(path)
        except Exception:
            # Not a safetensors file; fall back to the torch format. A bare
            # `except` here would also swallow KeyboardInterrupt.
            state_dict = torch_load(path)

    # Strip the torch.compile prefix - only where it really is a prefix, so
    # an inner occurrence of the same text in a key is left untouched.
    if isinstance(state_dict, dict) and any(k.startswith(compile_prefix) for k in state_dict):
        print(" Stripping torch.compile prefix...")
        state_dict = {
            (k[len(compile_prefix):] if k.startswith(compile_prefix) else k): v
            for k, v in state_dict.items()
        }

    return state_dict
385
+
386
 
387
def save_checkpoint(model, optimizer, scheduler, step, epoch, loss, path):
    """Write model weights and training state as two sibling files.

    The weights go to ``<path without extension>.safetensors`` with any
    ``torch.compile`` ``_orig_mod.`` key prefix removed. Step, epoch, loss
    and the optimizer/scheduler state dicts go to ``path`` via
    ``torch.save``.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler whose ``state_dict()`` is saved.
        step: Global step recorded in the checkpoint.
        epoch: Epoch recorded in the checkpoint.
        loss: Loss value recorded in the checkpoint.
        path: Destination of the training-state file (normally ``*.pt``).

    Returns:
        Path of the ``.safetensors`` weights file.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    compile_prefix = "_orig_mod."
    state_dict = {
        (name[len(compile_prefix):] if name.startswith(compile_prefix) else name): tensor
        for name, tensor in model.state_dict().items()
    }

    # Swap only the real extension. A plain str.replace(".pt", ...) would
    # also rewrite ".pt" inside directory names, and for a path without
    # ".pt" both files would land on the same name.
    root, ext = os.path.splitext(path)
    weights_path = root + ".safetensors"
    # Keep the two files distinct even if the caller passed a .safetensors path.
    state_path = root + ".pt" if ext == ".safetensors" else path

    save_file(state_dict, weights_path)

    torch.save({
        "step": step,
        "epoch": epoch,
        "loss": loss,
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }, state_path)
    print(f" ✓ Saved checkpoint: step {step}")
    return weights_path
407
 
408
+
409
+ def upload_checkpoint(weights_path, step, config):
410
+ """Upload to HuggingFace Hub."""
411
  try:
 
412
  api.upload_file(
413
  path_or_fileobj=weights_path,
414
  path_in_repo=f"checkpoints/step_{step}.safetensors",
 
416
  commit_message=f"Checkpoint step {step}",
417
  )
418
 
 
419
  config_path = os.path.join(CHECKPOINT_DIR, "config.json")
420
  with open(config_path, "w") as f:
421
  json.dump(config.__dict__, f, indent=2)
422
+ api.upload_file(path_or_fileobj=config_path, path_in_repo="config.json", repo_id=HF_REPO)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
 
424
+ print(f" ✓ Uploaded step {step} to {HF_REPO}")
425
  except Exception as e:
426
  print(f" ⚠ Upload failed: {e}")
427
 
428
+
429
# Keys the loader treats as harmless when absent from a checkpoint (the
# original inline note calls them precomputed constants).
_EXPECTED_MISSING_BUFFERS = frozenset({
    "time_in.sin_basis",
    "guidance_in.sin_basis",
    "rope.freqs_0",
    "rope.freqs_1",
    "rope.freqs_2",
})


def _apply_weights(model, weights, confirm_ok=False):
    """Load ``weights`` non-strictly and report keys that do not line up."""
    missing, unexpected = model.load_state_dict(weights, strict=False)
    if missing:
        actual_missing = set(missing) - _EXPECTED_MISSING_BUFFERS
        if actual_missing:
            print(f" ⚠ Unexpected missing keys: {actual_missing}")
        elif confirm_ok:
            print(" ✓ Missing only precomputed buffers (OK)")
    if unexpected:
        # strict=False drops these silently; surface them instead.
        print(f" ⚠ Ignored unexpected keys: {set(unexpected)}")


def _step_from_checkpoint_name(step_name):
    """Return the trailing integer of names like ``step_24000``, else 0."""
    if "_" not in step_name:
        return 0
    try:
        return int(step_name.rsplit("_", 1)[-1])
    except ValueError:
        # A non-numeric suffix should not abort an otherwise good load.
        return 0


def load_checkpoint(model, optimizer, scheduler, target):
    """Load model weights selected by ``target`` and return the resume point.

    Supported targets:
        * ``None`` / ``"none"`` - start fresh.
        * ``"hub"`` - ``model.safetensors`` from ``HF_REPO``.
        * ``"hub:<name>"`` - ``checkpoints/<name>.safetensors`` from
          ``HF_REPO``, falling back to ``checkpoints/<name>.pt``; the step
          is parsed from a ``<prefix>_<int>`` name.
        * ``"local:<path>"`` - a weights file on disk.

    Any other value (including ``"latest"``, ``"best"`` or an int) is not
    handled here and is reported as "No checkpoint found".

    Only the model weights are restored: ``optimizer`` and ``scheduler``
    are accepted for interface compatibility but never modified.

    Args:
        model: Module that receives the weights (loaded with ``strict=False``).
        optimizer: Unused.
        scheduler: Unused.
        target: Checkpoint selector, see above.

    Returns:
        ``(start_step, start_epoch)``. The epoch is always 0; the step is
        non-zero only for ``"hub:<prefix>_<int>"`` targets.
    """
    if target is None or target == "none":
        print("Starting fresh (no checkpoint)")
        return 0, 0

    # Hub loading
    if target == "hub" or (isinstance(target, str) and target.startswith("hub:")):
        start_step, start_epoch = 0, 0
        try:
            if target == "hub":
                weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
            else:
                step_name = target.split(":", 1)[1]
                try:
                    weights_path = hf_hub_download(
                        repo_id=HF_REPO, filename=f"checkpoints/{step_name}.safetensors"
                    )
                except Exception:
                    # No safetensors file for this name; try the .pt variant.
                    weights_path = hf_hub_download(
                        repo_id=HF_REPO, filename=f"checkpoints/{step_name}.pt"
                    )
                start_step = _step_from_checkpoint_name(step_name)

            _apply_weights(model, load_weights(weights_path), confirm_ok=True)
            print(f" Loaded from hub: {target}")
            return start_step, start_epoch
        except Exception as e:
            # Best effort: a failed hub load falls back to a fresh start.
            print(f"Hub load failed: {e}")
            return 0, 0

    # Local loading
    if isinstance(target, str) and target.startswith("local:"):
        path = target.split(":", 1)[1]
        _apply_weights(model, load_weights(path))
        print(f"✓ Loaded from local: {path}")
        return 0, 0

    print("No checkpoint found, starting fresh")
    return 0, 0
484
 
485
+
486
  # ============================================================================
487
+ # DATALOADER - Optimized
488
  # ============================================================================
489
+ loader = DataLoader(
490
+ ds,
491
+ batch_size=BATCH_SIZE,
492
+ shuffle=True,
493
+ collate_fn=collate_fn,
494
+ num_workers=num_workers, # 2 for precomputed, 0 for online
495
+ pin_memory=True,
496
+ persistent_workers=(num_workers > 0),
497
+ prefetch_factor=2 if num_workers > 0 else None,
498
+ )
499
 
500
  # ============================================================================
501
  # MODEL
 
503
  config = TinyFluxConfig()
504
  model = TinyFlux(config).to(DEVICE).to(DTYPE)
505
  print(f"\nParams: {sum(p.numel() for p in model.parameters()):,}")
 
506
 
507
  # ============================================================================
508
+ # OPTIMIZER - Fused for speed
509
  # ============================================================================
510
+ opt = torch.optim.AdamW(
511
+ model.parameters(),
512
+ lr=LR,
513
+ betas=(0.9, 0.99),
514
+ weight_decay=0.01,
515
+ fused=True,
516
+ )
517
+
518
  total_steps = len(loader) * EPOCHS // GRAD_ACCUM
519
  warmup = min(500, total_steps // 10)
520
 
521
+
522
def lr_fn(step):
    """LR multiplier: linear warmup followed by cosine decay to zero.

    Reads the module-level ``warmup`` and ``total_steps``. The multiplier
    rises linearly from 0 to 1 over the first ``warmup`` steps, then
    follows half a cosine down to 0 at ``total_steps`` and stays at 0 for
    any later step.
    """
    if step < warmup:
        return step / warmup

    decay_span = total_steps - warmup
    if decay_span <= 0:
        # No room for a decay phase; treat the schedule as finished rather
        # than dividing by zero.
        progress = 1.0
    else:
        # Clamp so steps past total_steps do not ride the cosine back up.
        progress = min((step - warmup) / decay_span, 1.0)
    return 0.5 * (1 + math.cos(math.pi * progress))
526
 
527
+
528
  sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_fn)
529
 
530
  # ============================================================================
531
+ # LOAD CHECKPOINT (before compile!)
532
  # ============================================================================
533
  print(f"\nLoad target: {LOAD_TARGET}")
534
  start_step, start_epoch = load_checkpoint(model, opt, sched, LOAD_TARGET)
535
 
 
536
  if RESUME_STEP is not None:
537
  print(f"Overriding start_step: {start_step} -> {RESUME_STEP}")
538
  start_step = RESUME_STEP
539
 
540
+ # ============================================================================
541
+ # COMPILE MODEL (after loading weights)
542
+ # ============================================================================
543
+ model = torch.compile(model, mode="default")
544
+
545
+ # Log config
546
  writer.add_text("config", json.dumps(config.__dict__, indent=2), 0)
547
  writer.add_text("training_config", json.dumps({
548
  "batch_size": BATCH_SIZE,
 
551
  "epochs": EPOCHS,
552
  "min_snr": MIN_SNR,
553
  "shift": SHIFT,
554
+ "optimizations": ["TF32", "fused_adamw", "precomputed_encodings", "flash_attention", "torch.compile"]
555
  }, indent=2), 0)
556
 
557
+ # Sample prompts
 
 
558
  SAMPLE_PROMPTS = [
559
  "a photo of a cat sitting on a windowsill",
560
  "a beautiful sunset over mountains",
 
563
  ]
564
 
565
  # ============================================================================
566
+ # TRAINING LOOP
567
  # ============================================================================
568
  print(f"\nTraining {EPOCHS} epochs, {total_steps} total steps")
569
  print(f"Resuming from step {start_step}, epoch {start_epoch}")
570
  print(f"Save: {SAVE_EVERY}, Upload: {UPLOAD_EVERY}, Sample: {SAMPLE_EVERY}, Log: {LOG_EVERY}")
571
+ print("Optimizations: TF32, fused AdamW, pre-encoded prompts, Flash Attention, torch.compile")
572
 
573
  model.train()
574
  step = start_step
575
  best = float("inf")
576
 
577
+ # Pre-create img_ids for common resolution (will be cached)
578
+ _cached_img_ids = None
579
+
580
  for ep in range(start_epoch, EPOCHS):
581
  ep_loss = 0
582
  ep_batches = 0
583
+ pbar = tqdm(loader, desc=f"E{ep + 1}")
584
 
585
  for i, batch in enumerate(pbar):
586
+ # Move to GPU here (not in collate, to support multiprocessing)
587
+ latents = batch["latents"].to(DEVICE, non_blocking=True)
588
+ t5 = batch["t5_embeds"].to(DEVICE, non_blocking=True)
589
+ clip = batch["clip_pooled"].to(DEVICE, non_blocking=True)
590
 
591
  B, C, H, W = latents.shape
592
 
593
+ # Reshape: (B, C, H, W) -> (B, H*W, C)
594
+ data = latents.permute(0, 2, 3, 1).reshape(B, H * W, C)
595
+ noise = torch.randn_like(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
596
 
597
+ # Sample timesteps with logit-normal + flux shift
 
598
  t = torch.sigmoid(torch.randn(B, device=DEVICE))
599
+ t = flux_shift(t, s=SHIFT).to(DTYPE).clamp(1e-4, 1 - 1e-4)
600
 
601
+ # Linear interpolation
602
  t_expanded = t.view(B, 1, 1)
603
+ x_t = (1 - t_expanded) * noise + t_expanded * data
604
 
605
+ # Velocity target
606
  v_target = data - noise
607
 
608
+ # Get img_ids (cached in model)
609
  img_ids = TinyFlux.create_img_ids(B, H, W, DEVICE)
610
 
611
+ # Random guidance
612
+ guidance = torch.rand(B, device=DEVICE, dtype=DTYPE) * 4 + 1
613
 
614
+ # Forward
615
  with torch.autocast("cuda", dtype=DTYPE):
616
  v_pred = model(
617
  hidden_states=x_t,
 
622
  guidance=guidance,
623
  )
624
 
625
+ # Loss with Min-SNR weighting
626
  loss_raw = F.mse_loss(v_pred, v_target, reduction="none").mean(dim=[1, 2])
 
 
627
  snr_weights = min_snr_weight(t)
628
  loss = (loss_raw * snr_weights).mean() / GRAD_ACCUM
629
  loss.backward()
 
632
  grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
633
  opt.step()
634
  sched.step()
635
+ opt.zero_grad(set_to_none=True) # Slightly faster than zero_grad()
636
  step += 1
637
 
 
638
  if step % LOG_EVERY == 0:
639
  writer.add_scalar("train/loss", loss.item() * GRAD_ACCUM, step)
640
  writer.add_scalar("train/lr", sched.get_last_lr()[0], step)
641
  writer.add_scalar("train/grad_norm", grad_norm.item(), step)
642
  writer.add_scalar("train/t_mean", t.mean().item(), step)
 
643
 
 
644
  if step % SAMPLE_EVERY == 0:
645
  print(f"\n Generating samples at step {step}...")
646
  images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
647
  save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)
648
 
 
649
  if step % SAVE_EVERY == 0:
650
  ckpt_path = os.path.join(CHECKPOINT_DIR, f"step_{step}.pt")
651
  weights_path = save_checkpoint(model, opt, sched, step, ep, loss.item(), ckpt_path)
652
 
 
653
  if step % UPLOAD_EVERY == 0:
654
+ upload_checkpoint(weights_path, step, config)
655
 
656
  ep_loss += loss.item() * GRAD_ACCUM
657
  ep_batches += 1
658
+ pbar.set_postfix(loss=f"{loss.item() * GRAD_ACCUM:.4f}", lr=f"{sched.get_last_lr()[0]:.1e}", step=step)
659
 
660
  avg = ep_loss / max(ep_batches, 1)
661
+ print(f"Epoch {ep + 1} loss: {avg:.4f}")
662
  writer.add_scalar("train/epoch_loss", avg, ep + 1)
663
 
664
  if avg < best:
 
671
  path_or_fileobj=weights_path,
672
  path_in_repo="model.safetensors",
673
  repo_id=HF_REPO,
674
+ commit_message=f"Best model (epoch {ep + 1}, loss {avg:.4f})",
675
  )
676
  print(f" ✓ Uploaded best to {HF_REPO}")
677
  except Exception as e:
 
684
  final_path = os.path.join(CHECKPOINT_DIR, "final.pt")
685
  weights_path = save_checkpoint(model, opt, sched, step, EPOCHS, best, final_path)
686
 
 
687
  print("Generating final samples...")
688
  images = generate_samples(model, SAMPLE_PROMPTS, num_steps=20)
689
  save_samples(images, SAMPLE_PROMPTS, step, SAMPLE_DIR)
690
 
 
691
  try:
692
  api.upload_file(path_or_fileobj=weights_path, path_in_repo="model.safetensors", repo_id=HF_REPO)
693
  config_path = os.path.join(CHECKPOINT_DIR, "config.json")
694
  with open(config_path, "w") as f:
695
  json.dump(config.__dict__, f, indent=2)
696
  api.upload_file(path_or_fileobj=config_path, path_in_repo="config.json", repo_id=HF_REPO)
 
 
697
  print(f"\n✓ Training complete! https://huggingface.co/{HF_REPO}")
698
  except Exception as e:
699
  print(f"\n⚠ Final upload failed: {e}")