| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
# Stage 1: pre-encode the synthetic ImageNet subset into an on-disk cache.
#
# The encoding pass only runs when CACHE_FILE does not exist yet; on later
# runs the existing file is reused as-is.
import gc
import os

import torch

CACHE_DIR = "/content/latent_cache"
CACHE_FILE = os.path.join(CACHE_DIR, "imagenet_synthetic_flux_10k.pt")
os.makedirs(CACHE_DIR, exist_ok=True)

if os.path.exists(CACHE_FILE):
    print(f"✓ Cache exists: {CACHE_FILE}")
else:
    # Imported inside the branch so these helpers are only loaded when the
    # cache actually has to be built.
    from sd15_trainer_geo.pipeline import load_pipeline
    from sd15_trainer_geo.trainer import pre_encode_hf_dataset

    # Temporary fp16 pipeline used only for this encoding pass; it is
    # discarded below and a separate pipeline is loaded for training.
    pipe = load_pipeline(device="cuda", dtype=torch.float16)

    pre_encode_hf_dataset(
        pipe,
        dataset_name="AbstractPhil/imagenet-synthetic",
        subset="flux_schnell_512",
        split="train",
        image_column="image",
        prompt_column="prompt",
        output_path=CACHE_FILE,
        image_size=512,
        batch_size=16,
    )

    # Release the encoding pipeline before training loads its own copy.
    # gc.collect() breaks any reference cycles still holding its CUDA tensors;
    # only after that can empty_cache() hand the freed blocks back.
    del pipe
    gc.collect()
    torch.cuda.empty_cache()
    print("✓ Encoding complete, VRAM cleared")
|
|
| |
| |
| |
# Stage 2: build the pipeline that will be fine-tuned and load the pretrained
# UNet checkpoint into it.
from sd15_trainer_geo.pipeline import load_pipeline
from sd15_trainer_geo.trainer import LatentDataset, Trainer, TrainConfig
from sd15_trainer_geo.generate import generate, save_images, show_images

pipe = load_pipeline(device="cuda", dtype=torch.float16)

# Where the UNet weights come from (presumably a Hugging Face Hub repo --
# confirm against UNet.load_pretrained).
unet_checkpoint = dict(
    repo_id="AbstractPhil/tinyflux-experts",
    subfolder="",
    filename="sd15-flow-lune-unet.safetensors",
)
pipe.unet.load_pretrained(**unet_checkpoint)
|
|
| |
# Reference images from the pretrained weights, produced before any training
# step so later samples can be compared against them.
print("\n--- Pre-training baseline ---")
baseline_prompts = [
    "a tabby cat on a windowsill",
    "mountains at sunset, landscape painting",
    "a bowl of ramen, studio photography",
    "an astronaut riding a horse on mars",
]
pre_out = generate(
    pipe,
    baseline_prompts,
    num_steps=25,
    cfg_scale=7.5,
    shift=2.5,
    seed=42,
)
save_images(pre_out, "/content/baseline_samples")
show_images(pre_out)
|
|
| |
| |
| |
# Stage 3: fine-tune the pipeline on the cached latents.
dataset = LatentDataset(CACHE_FILE)

config = TrainConfig(
    # Optimisation budget. 1667 steps at batch 6 is ~10k samples, i.e. about
    # one pass if the cache holds ~10k items as its filename suggests
    # (TODO confirm against len(dataset)).
    num_steps=1667,
    batch_size=6,
    base_lr=1e-4,
    weight_decay=0.01,

    # Timestep sampling. shift matches the shift=2.5 passed to generate()
    # for every sample set in this script.
    shift=2.5,
    t_sample="logit_normal",
    logit_normal_mean=0.0,
    logit_normal_std=1.0,
    t_min=0.001,
    t_max=1.0,

    # Presumably the probability of dropping the text condition so that
    # classifier-free guidance works at inference -- confirm in TrainConfig.
    cfg_dropout=0.1,

    # Presumably the Min-SNR loss-weighting gamma -- confirm in TrainConfig.
    min_snr_gamma=5.0,

    # Auxiliary "geo" loss weight and its warm-up length in steps (semantics
    # defined by Trainer; hedged).
    geo_loss_weight=0.01,
    geo_loss_warmup=200,

    # Learning-rate schedule.
    lr_scheduler="cosine",
    warmup_steps=100,
    min_lr=1e-6,

    # Mixed precision and gradient clipping.
    use_amp=True,
    grad_clip=1.0,

    # Logging, periodic sampling and checkpointing. The sample prompts and
    # sampler settings mirror the baseline/post-training generate() calls so
    # intermediate samples compare like-for-like with them.
    log_every=50,
    sample_every=500,
    save_every=500,
    sample_prompts=[
        "a tabby cat on a windowsill",
        "mountains at sunset, landscape painting",
        "a bowl of ramen, studio photography",
        "an astronaut riding a horse on mars",
    ],
    sample_steps=25,
    sample_cfg=7.5,

    # Outputs; hub_repo_id=None presumably disables pushing to the Hub.
    output_dir="/content/geo_train_imagenet",
    hub_repo_id=None,

    # Data loading and reproducibility.
    num_workers=2,
    pin_memory=True,
    seed=42,
)

trainer = Trainer(pipe, config)
trainer.fit(dataset)
|
|
| |
| |
| |
# Regenerate the baseline prompts with the same sampler settings and seed
# (25 steps, cfg 7.5, shift 2.5, seed 42) on the fine-tuned pipeline.
print("\n--- Post-training samples ---")
post_prompts = [
    "a tabby cat on a windowsill",
    "mountains at sunset, landscape painting",
    "a bowl of ramen, studio photography",
    "an astronaut riding a horse on mars",
]
post_out = generate(
    pipe,
    post_prompts,
    num_steps=25,
    cfg_scale=7.5,
    shift=2.5,
    seed=42,
)
save_images(post_out, "/content/post_train_samples")
show_images(post_out)
|
|
| |
# Generalisation check on prompts outside the baseline set, sampled with a
# different seed (123) than the baseline/post-training runs.
print("\n--- Novel prompts (not in training set) ---")
novel_prompts = [
    "a cyberpunk cityscape at night with neon lights",
    "a golden retriever playing in autumn leaves",
    "a steampunk clocktower, detailed illustration",
    "an underwater coral reef, macro photography",
]
novel_sampler_kwargs = dict(num_steps=25, cfg_scale=7.5, shift=2.5, seed=123)
novel_out = generate(pipe, novel_prompts, **novel_sampler_kwargs)
save_images(novel_out, "/content/novel_samples")
show_images(novel_out)
|
|
| |
# Summarise training: the number of logged entries, then first -> last values
# of the loss components. The loss/task_loss/geo_loss keys are required;
# t_mean/t_std fall back to 0 when absent.
print(f"\nTraining: {len(trainer.log_history)} logged steps")
if trainer.log_history:
    first, last = trainer.log_history[0], trainer.log_history[-1]
    print(f" Loss: {first['loss']:.4f} → {last['loss']:.4f}")
    print(f" Task: {first['task_loss']:.4f} → {last['task_loss']:.4f}")
    print(f" Geo: {first['geo_loss']:.6f} → {last['geo_loss']:.6f}")
    print(f" t_mean: {last.get('t_mean', 0):.3f} ± {last.get('t_std', 0):.3f}")