File size: 17,211 Bytes
c4858e4
ce2ad4d
 
 
67a401e
 
 
 
d0236fe
67a401e
c4858e4
 
 
 
 
ce2ad4d
c4858e4
 
 
 
 
ce2ad4d
 
 
 
 
 
 
 
 
 
d0236fe
67a401e
ce2ad4d
 
 
 
 
 
 
67a401e
ce2ad4d
 
0b46772
 
ce2ad4d
 
 
57090a0
d0236fe
67a401e
0b46772
 
 
 
 
 
 
67a401e
0b46772
ce2ad4d
 
c4858e4
 
57090a0
 
 
 
 
 
 
 
 
67a401e
 
 
 
 
57090a0
 
d0236fe
 
 
 
 
 
 
c4858e4
 
57090a0
 
c4858e4
57090a0
 
d0236fe
551424e
 
 
d0236fe
ce2ad4d
c4858e4
 
2403335
ce2ad4d
 
c4858e4
 
67a401e
c4858e4
 
 
ce2ad4d
 
 
c4858e4
ce2ad4d
c4858e4
 
ce2ad4d
c4858e4
 
 
 
 
ce2ad4d
c4858e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce2ad4d
 
 
 
 
67a401e
0b46772
67a401e
57090a0
a1ff09a
57090a0
ce2ad4d
 
 
 
 
 
0b46772
67a401e
 
ce2ad4d
 
 
 
 
0b46772
ce2ad4d
57090a0
 
ce2ad4d
 
67a401e
ce2ad4d
 
 
 
57090a0
ce2ad4d
 
 
 
57090a0
ce2ad4d
 
67a401e
 
 
 
 
 
 
 
 
 
d0236fe
67a401e
57090a0
0b46772
57090a0
 
67a401e
ce2ad4d
 
 
57090a0
ce2ad4d
57090a0
 
ce2ad4d
57090a0
 
 
 
 
 
 
ce2ad4d
67a401e
ce2ad4d
57090a0
 
ce2ad4d
57090a0
67a401e
 
 
 
d0236fe
ce2ad4d
57090a0
ce2ad4d
57090a0
 
ce2ad4d
57090a0
ce2ad4d
 
 
0b46772
 
 
67a401e
0b46772
ce2ad4d
d0236fe
ce2ad4d
57090a0
ce2ad4d
a1ff09a
 
c4858e4
ce2ad4d
c4858e4
ce2ad4d
c4858e4
ce2ad4d
 
 
 
 
 
 
57090a0
ce2ad4d
 
57090a0
c4858e4
 
 
 
a1ff09a
 
ce2ad4d
 
a1ff09a
ce2ad4d
a1ff09a
c4858e4
 
ce2ad4d
57090a0
ce2ad4d
 
 
 
 
 
57090a0
ce2ad4d
c4858e4
 
 
ce2ad4d
 
 
 
 
a1ff09a
c4858e4
ce2ad4d
c4858e4
 
 
57090a0
ce2ad4d
57090a0
 
 
 
67a401e
 
ce2ad4d
c4858e4
ce2ad4d
 
 
 
 
 
 
 
 
2403335
ce2ad4d
57090a0
 
67a401e
ce2ad4d
 
 
 
 
 
 
 
 
d0236fe
ce2ad4d
 
67a401e
a1ff09a
ce2ad4d
 
 
 
 
 
 
 
 
 
 
67a401e
 
ce2ad4d
 
 
 
 
 
 
d0236fe
 
 
ce2ad4d
d0236fe
 
 
 
ce2ad4d
57090a0
ce2ad4d
 
 
57090a0
ce2ad4d
 
 
57090a0
551424e
ce2ad4d
 
57090a0
c4858e4
57090a0
 
ce2ad4d
 
57090a0
 
d0236fe
 
 
ce2ad4d
 
57090a0
d0236fe
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
"""
LiquidGen Training Pipeline v2

Optimized for Colab free tier:
- Fast VAE encoding: batch=64 for 256px, batch=32 for 512px (~5x faster)
- Auto-limits large datasets (WikiArt capped at 10K by default)
- Latent pre-caching: train on pure tensors, no VAE during training
- Gradient checkpointing + auto batch size = no OOM
- ETA shown on every log line
- All datasets pure parquet, open SDXL VAE (no login)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast, GradScaler
import math
import os
import json
import time
from dataclasses import dataclass, asdict


def _preset(name, description, *, label_column="", max_default=0):
    """Build one dataset-preset entry; every preset shares this key layout."""
    return {
        "name": name,
        "config": "",
        "image_column": "image",
        "label_column": label_column,
        "num_classes": 0,
        "max_default": max_default,
        "description": description,
    }


# Hugging Face datasets selectable through TrainConfig.dataset_preset.
# "max_default" > 0 caps how many images precache_latents encodes unless
# TrainConfig.max_images overrides it; 0 means "encode the whole split".
# A non-empty "label_column" makes precache_latents record per-image labels.
DATASET_PRESETS = {
    "cartoon": _preset(
        "Norod78/cartoon-blip-captions",
        "~2.5K cartoon/anime, unconditional, 181MB — fast",
    ),
    "flowers": _preset(
        "huggan/flowers-102-categories",
        "~8K flower photos, unconditional, 331MB",
    ),
    "wikiart": _preset(
        "Artificio/WikiArt",
        "~105K paintings with styles (auto-capped to 10K for speed)",
        label_column="style",
        max_default=10000,
    ),
    "art_painting": _preset(
        "huggan/few-shot-art-painting",
        "~6K art paintings, unconditional, 511MB",
    ),
}


def auto_batch_size(model_size, image_size, gpu_mem_gb):
    """Pick a training batch size expected to fit in ``gpu_mem_gb`` GB of VRAM.

    Estimate: (VRAM - fixed per-model footprint - 1.5 GB headroom) divided by
    a per-sample activation cost, using hand-tuned GB figures per model size
    and resolution. Unknown model sizes fall back to a 1.0 GB footprint and
    unknown size/resolution pairs to 0.1 GB per sample. The raw estimate
    (never below 1) is snapped down to 32, 16, 8 or 4; smaller values are
    returned unchanged.
    """
    fixed_gb = {"small": 0.66, "base": 1.68, "large": 3.35}.get(model_size, 1.0)
    per_sample_gb = {
        "small": {256: 0.02, 512: 0.07},
        "base": {256: 0.03, 512: 0.13},
        "large": {256: 0.05, 512: 0.21},
    }.get(model_size, {}).get(image_size, 0.1)
    fits = max(1, int((gpu_mem_gb - fixed_gb - 1.5) / per_sample_gb))
    for bucket in (32, 16, 8, 4):
        if fits >= bucket:
            return bucket
    return fits


def _fmt_time(seconds):
    """Format seconds into human readable string."""
    if seconds < 60: return f"{seconds:.0f}s"
    if seconds < 3600: return f"{seconds/60:.1f}m"
    return f"{seconds/3600:.1f}h"


@dataclass
class TrainConfig:
    """Hyper-parameters and I/O settings read by ``train()`` and ``precache_latents()``.

    Field order is the positional constructor order; append new fields at the end.
    """
    # --- model ---
    model_size: str = "small"  # "small" | "base" | "large" (keys of get_model_config)
    num_classes: int = 0  # > 0: labels are passed to the model and sampled for CFG; 0: unconditional
    class_drop_prob: float = 0.1  # forwarded to the model config; presumably label-dropout rate for CFG — confirm in LiquidGen
    # --- data / VAE ---
    dataset_preset: str = "cartoon"  # key of DATASET_PRESETS
    image_size: int = 256  # resize + center-crop side length; latent side is image_size // 8
    max_images: int = 0  # 0: use the preset's "max_default" if > 0, else the whole split
    vae_id: str = "madebyollin/sdxl-vae-fp16-fix"  # AutoencoderKL repo id, loaded in fp16
    vae_scaling_factor: float = 0.13025  # latents are multiplied by this after encoding, divided before decoding
    latent_channels: int = 4  # model in_channels and channel dim of sampled latents
    # --- optimisation ---
    batch_size: int = 0  # 0: choose via auto_batch_size (4 when no CUDA GPU)
    gradient_accumulation_steps: int = 1  # micro-batches per optimizer step; loss is divided by this
    learning_rate: float = 1e-4  # AdamW peak LR
    weight_decay: float = 0.01  # AdamW weight decay
    max_grad_norm: float = 2.0  # gradient-norm clipping threshold
    num_epochs: int = 100
    warmup_steps: int = 500  # optimizer steps of linear LR warm-up before cosine decay
    ema_decay: float = 0.9999  # decay passed to EMAModel
    mixed_precision: bool = True  # autocast + GradScaler during training (only when CUDA is available)
    gradient_checkpointing: bool = True  # calls model.enable_gradient_checkpointing()
    min_timestep: float = 0.001  # lower bound of training timesteps
    max_timestep: float = 0.999  # upper bound of training timesteps
    # --- outputs / logging (step counts are optimizer steps) ---
    output_dir: str = "./outputs"  # holds cached_latents.pt plus samples/ and checkpoints/
    save_every_n_steps: int = 2000
    sample_every_n_steps: int = 500
    log_every_n_steps: int = 25
    # --- sampling ---
    num_sample_steps: int = 50  # Euler steps per sample
    cfg_scale: float = 2.0  # guidance scale; only applied when > 1 and num_classes > 0
    num_samples: int = 4  # images per saved sample grid (grid uses nrow=2)
    # --- misc ---
    seed: int = 42  # torch.manual_seed
    num_workers: int = 2  # DataLoader workers
    # NOTE(review): the three fields below are not read anywhere in the code visible in this file.
    compile_model: bool = False
    push_to_hub: bool = False
    hub_model_id: str = ""


def get_model_config(size, num_classes=0, class_drop_prob=0.1):
    """Return a fresh dict of LiquidGen constructor kwargs for a named size.

    Args:
        size: "small", "base" or "large".
        num_classes: number of label classes (0 = unconditional).
        class_drop_prob: forwarded unchanged to the model.

    Raises:
        KeyError: if ``size`` is not one of the known names.
    """
    # size -> (embed_dim, depth, expand_ratio, mlp_ratio); kernels are shared.
    table = {
        "small": (512, 12, 2.0, 3.0),
        "base": (640, 18, 2.0, 4.0),
        "large": (768, 24, 2.5, 4.0),
    }
    embed_dim, depth, expand_ratio, mlp_ratio = table[size]
    return {
        "embed_dim": embed_dim,
        "depth": depth,
        "spatial_kernel": 7,
        "scan_kernel": 31,
        "expand_ratio": expand_ratio,
        "mlp_ratio": mlp_ratio,
        "num_classes": num_classes,
        "class_drop_prob": class_drop_prob,
        "use_zigzag": True,
    }


class CachedLatentDataset(Dataset):
    """Serve pre-encoded latents, plus optional integer labels, from a cache file.

    The file is a ``torch.save``'d dict with a ``"latents"`` tensor and,
    optionally, a ``"labels"`` tensor in which ``-1`` marks unlabeled items
    (the format written by ``precache_latents``). Items are
    ``(latent, label)``; the label is ``-1`` when the cache has no labels.
    """

    def __init__(self, cache_path):
        payload = torch.load(cache_path, map_location="cpu", weights_only=True)
        self.latents = payload["latents"]
        self.labels = payload.get("labels", None)
        print(f"Loaded {len(self.latents)} cached latents: {self.latents.shape}")
        if self.labels is None:
            return
        labeled = self.labels >= 0
        if labeled.any():
            n_classes = self.labels[labeled].unique().shape[0]
            print(f"  {n_classes} classes")

    def __len__(self):
        return len(self.latents)

    def __getitem__(self, idx):
        latent = self.latents[idx]
        if self.labels is None:
            return latent, -1
        return latent, self.labels[idx]


def precache_latents(config, cache_path=None):
    if cache_path is None:
        cache_path = os.path.join(config.output_dir, "cached_latents.pt")
    if os.path.exists(cache_path):
        print(f"Cache exists: {cache_path}")
        d = torch.load(cache_path, map_location="cpu", weights_only=True)
        print(f"  {d['latents'].shape[0]} latents {d['latents'].shape[1:]}")
        return cache_path

    os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Loading VAE: {config.vae_id}...")
    from diffusers import AutoencoderKL
    vae = AutoencoderKL.from_pretrained(config.vae_id, torch_dtype=torch.float16).to(device).eval()
    for p in vae.parameters(): p.requires_grad_(False)

    preset = DATASET_PRESETS[config.dataset_preset]
    print(f"Dataset: {preset['name']}")
    from datasets import load_dataset
    from torchvision import transforms

    ds_kwargs = {"split": "train"}
    if preset["config"]: ds_kwargs["name"] = preset["config"]
    dataset = load_dataset(preset["name"], **ds_kwargs)

    transform = transforms.Compose([
        transforms.Resize(config.image_size, interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.CenterCrop(config.image_size), transforms.ToTensor(),
    ])

    if config.max_images > 0:
        max_imgs = config.max_images
    elif preset.get("max_default", 0) > 0:
        max_imgs = preset["max_default"]
        print(f"  Auto-capping to {max_imgs} images (set max_images to override)")
    else:
        max_imgs = len(dataset)
    max_imgs = min(max_imgs, len(dataset))

    encode_bs = 64 if config.image_size <= 256 else 32
    print(f"  Encoding {max_imgs} images (batch={encode_bs})...")

    img_col, lbl_col = preset["image_column"], preset["label_column"]
    style_to_id = {}
    all_latents, all_labels = [], []
    batch_px, batch_lb = [], []
    count = 0
    t0 = time.time()

    for item in dataset:
        if count >= max_imgs: break
        img = item[img_col]
        if img.mode != "RGB": img = img.convert("RGB")
        batch_px.append(transform(img))
        if lbl_col and lbl_col in item:
            raw = item[lbl_col]
            if isinstance(raw, str):
                if raw not in style_to_id: style_to_id[raw] = len(style_to_id)
                batch_lb.append(style_to_id[raw])
            elif isinstance(raw, int): batch_lb.append(raw)
            else: batch_lb.append(-1)
        else: batch_lb.append(-1)
        count += 1
        if len(batch_px) >= encode_bs:
            with torch.no_grad():
                px = torch.stack(batch_px).to(device, dtype=torch.float16) * 2 - 1
                lat = vae.encode(px).latent_dist.sample() * config.vae_scaling_factor
                all_latents.append(lat.cpu().float())
            all_labels.extend(batch_lb); batch_px, batch_lb = [], []
            elapsed = time.time() - t0
            speed = count / elapsed
            eta = (max_imgs - count) / speed if speed > 0 else 0
            if count % (encode_bs * 4) == 0:
                print(f"  {count}/{max_imgs} | {speed:.0f} img/s | ETA {_fmt_time(eta)}")

    if batch_px:
        with torch.no_grad():
            px = torch.stack(batch_px).to(device, dtype=torch.float16) * 2 - 1
            lat = vae.encode(px).latent_dist.sample() * config.vae_scaling_factor
            all_latents.append(lat.cpu().float())
        all_labels.extend(batch_lb)

    all_latents = torch.cat(all_latents, dim=0)
    all_labels = torch.tensor(all_labels, dtype=torch.long)
    save_data = {"latents": all_latents, "labels": all_labels}
    if style_to_id:
        save_data["style_to_id"] = style_to_id
        print(f"  {len(style_to_id)} style classes")
    torch.save(save_data, cache_path)
    mb = os.path.getsize(cache_path) / 1024**2
    print(f"Cached {count} latents -> {cache_path} ({mb:.0f}MB, {_fmt_time(time.time()-t0)})")
    del vae
    if torch.cuda.is_available(): torch.cuda.empty_cache()
    return cache_path


class EMAModel:
    """Exponential moving average of a model's trainable parameters.

    ``shadow`` maps parameter names to their EMA copies. ``apply`` swaps the
    EMA weights into the live model (keeping a backup) and ``restore`` puts
    the live weights back, e.g. around sampling.

    Args:
        model: module whose ``requires_grad`` parameters are tracked.
        decay: target EMA decay.
        warmup: if True (default), the n-th update (n starting at 0) uses
            ``min(decay, (1 + n) / (10 + n))``. With a fixed 0.9999 decay the
            average has a ~10K-update time constant. A shorter run would
            therefore end with EMA weights still dominated by the random
            initialisation. Pass ``warmup=False`` for a fixed decay.
    """

    def __init__(self, model, decay=0.9999, warmup=True):
        self.decay = decay
        self.warmup = warmup
        self.num_updates = 0
        self.shadow = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}
        # Initialised so restore() before any apply() is a harmless no-op
        # instead of an AttributeError.
        self.backup = {}

    def current_decay(self):
        """Decay that the next ``update`` call will use."""
        if not self.warmup:
            return self.decay
        n = self.num_updates
        return min(self.decay, (1 + n) / (10 + n))

    @torch.no_grad()
    def update(self, model):
        """Blend the model's current parameters into ``shadow``."""
        decay = self.current_decay()
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.shadow:
                self.shadow[n].mul_(decay).add_(p.data, alpha=1 - decay)
        self.num_updates += 1

    def apply(self, model):
        """Back up the live parameters and load the EMA weights into ``model``."""
        self.backup = {n: p.data.clone() for n, p in model.named_parameters() if p.requires_grad}
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.shadow:
                p.data.copy_(self.shadow[n])

    def restore(self, model):
        """Copy the backed-up live parameters back into ``model`` and clear the backup."""
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.backup:
                p.data.copy_(self.backup[n])
        self.backup = {}


class FlowMatchingScheduler:
    """Linear-interpolation flow-matching schedule with a fixed-step Euler sampler.

    Training uses ``x_t = (1 - t) * x0 + t * noise`` with velocity target
    ``noise - x0``. Sampling integrates the predicted velocity from t=1
    (pure noise) down towards t=0.

    Args:
        min_t, max_t: range of training timesteps drawn by ``sample_timesteps``.
    """

    def __init__(self, min_t=0.001, max_t=0.999):
        self.min_t, self.max_t = min_t, max_t

    def sample_timesteps(self, bs, dev):
        """Return ``bs`` timesteps drawn uniformly from ``[min_t, max_t)`` on ``dev``."""
        return torch.rand(bs, device=dev) * (self.max_t - self.min_t) + self.min_t

    def add_noise(self, x0, noise, t):
        """Blend clean data and noise: ``(1 - t) * x0 + t * noise``.

        ``t`` holds one timestep per batch element. It is broadcast over all
        remaining dimensions of ``x0``, so inputs of any rank work, not only
        4-D latents.
        """
        t = t.view(-1, *([1] * (x0.dim() - 1)))
        return (1 - t) * x0 + t * noise

    def get_velocity_target(self, x0, noise):
        """Velocity of the interpolation path, ``d x_t / d t = noise - x0``."""
        return noise - x0

    @torch.no_grad()
    def sample(self, model, shape, dev, num_steps=50, labels=None, cfg=1.0):
        """Generate samples with ``num_steps`` Euler steps from t=1 to t=0.

        Args:
            model: called as ``model(x, t, labels)``; its output is used as the
                velocity for ``x``. It is switched to eval mode and left there;
                the caller must restore train mode.
            shape: shape of the sample tensor; ``shape[0]`` is the batch size.
            dev: device (``torch.device`` or string) to sample on.
            num_steps: number of Euler steps.
            labels: optional class labels. Classifier-free guidance is applied
                only when labels are given and ``cfg > 1``.
            cfg: guidance scale.

        Returns:
            float32 tensor of ``shape``.
        """
        model.eval()
        # CUDA autocast has no effect on non-CUDA tensors, so enable it only
        # when sampling on CUDA. This also avoids a warning on CPU-only hosts.
        on_cuda = torch.device(dev).type == "cuda"
        x = torch.randn(shape, device=dev)
        dt = 1.0 / num_steps
        for tv in torch.linspace(1.0, dt, num_steps, device=dev):
            t = torch.full((shape[0],), tv.item(), device=dev)
            with torch.amp.autocast("cuda", enabled=on_cuda):
                if cfg > 1.0 and labels is not None:
                    # NOTE(review): the unconditional branch uses label 0, which is also a
                    # real class id for sampled labels; confirm this matches how LiquidGen
                    # represents its "null" class.
                    v_cond = model(x, t, labels)
                    v_uncond = model(x, t, torch.zeros_like(labels))
                    v = v_uncond + cfg * (v_cond - v_uncond)
                else:
                    v = model(x, t, labels)
            x = x - dt * v.float()
        return x


def cosine_schedule(opt, warmup, total):
    """LambdaLR with linear warm-up over ``warmup`` steps, then cosine decay to 0 at ``total``.

    Decay progress is clamped at 1. Steps past ``total`` therefore stay at a
    factor of 0; previously the cosine climbed back up (e.g. 0.5 at
    1.5x the decay length).

    Args:
        opt: optimizer whose LR is scaled.
        warmup: number of warm-up steps; the factor ramps from 0 to 1.
        total: step at which the factor reaches 0.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR``.
    """
    def lr(s):
        if s < warmup:
            return s / max(1, warmup)
        progress = min(1.0, (s - warmup) / max(1, total - warmup))
        return max(0.0, 0.5 * (1 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(opt, lr)


def train(config):
    """Train LiquidGen with flow matching on pre-cached VAE latents.

    Pipeline:
      1. Pick the device and batch size, then cache latents (``precache_latents``).
      2. Build the model, AdamW with warm-up + cosine LR, EMA and an AMP
         grad scaler.
      3. Run ``config.num_epochs`` epochs.

    During training it:
      - logs every ``log_every_n_steps`` optimizer steps;
      - writes EMA sample grids every ``sample_every_n_steps``;
      - checkpoints every ``save_every_n_steps``.

    A final checkpoint goes to ``<output_dir>/checkpoints/final.pt``. Returns
    early, without a final checkpoint, if the logged loss is NaN or above 50.

    Args:
        config: a ``TrainConfig``. ``config.batch_size`` is overwritten in
            place when it is <= 0.
    """
    from model import LiquidGen
    torch.manual_seed(config.seed)
    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")
    amp_enabled = config.mixed_precision and cuda
    gpu_mem = 0
    if cuda:
        # The CUDA device-properties field is ``total_memory`` (bytes); the
        # former ``total_mem`` raised AttributeError on every GPU run.
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"GPU: {torch.cuda.get_device_name(0)} ({gpu_mem:.1f} GB)")

    if config.batch_size <= 0:
        config.batch_size = auto_batch_size(config.model_size, config.image_size, gpu_mem) if gpu_mem > 0 else 4
        print(f"Auto batch: {config.batch_size}")

    os.makedirs(config.output_dir, exist_ok=True)
    os.makedirs(f"{config.output_dir}/samples", exist_ok=True)
    os.makedirs(f"{config.output_dir}/checkpoints", exist_ok=True)

    cache_path = precache_latents(config)
    train_ds = CachedLatentDataset(cache_path)
    train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
                          num_workers=config.num_workers, pin_memory=True, drop_last=True)

    mcfg = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
    mcfg["in_channels"] = config.latent_channels
    model = LiquidGen(**mcfg).to(device)
    if config.gradient_checkpointing:
        model.enable_gradient_checkpointing()
    print(f"LiquidGen-{config.model_size}: {model.count_params()/1e6:.1f}M (ckpt={'ON' if config.gradient_checkpointing else 'OFF'})")

    opt = torch.optim.AdamW(model.parameters(), lr=config.learning_rate,
                            weight_decay=config.weight_decay, betas=(0.9, 0.999))
    accum = config.gradient_accumulation_steps
    # Only complete accumulation groups yield an optimizer step. Trailing
    # micro-batches in an epoch are skipped so their gradients don't leak into
    # the next epoch's first step. This also makes total_steps, which drives
    # the LR schedule, percent and ETA, equal to the real number of steps.
    steps_per_epoch = len(train_dl) // accum
    usable_batches = steps_per_epoch * accum
    total_steps = steps_per_epoch * config.num_epochs
    if total_steps == 0:
        print(f"Warning: {len(train_ds)} latents give no full optimizer step "
              f"(batch={config.batch_size}, accum={accum}); nothing will be trained.")
    sched = cosine_schedule(opt, config.warmup_steps, total_steps)
    ema = EMAModel(model, config.ema_decay)
    scaler = GradScaler("cuda", enabled=amp_enabled)
    fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep)
    lat_size = config.image_size // 8  # VAE downsamples 8x per side
    print(f"Steps: {total_steps} | Batch: {config.batch_size} | Epochs: {config.num_epochs}")

    gs = 0; la = 0.0; vae = None  # VAE is loaded lazily at the first sampling step
    print(f"\nTraining!\n")
    t_start = time.time()

    for epoch in range(config.num_epochs):
        model.train(); et = time.time()
        for bi, (lats, lbls) in enumerate(train_dl):
            if bi >= usable_batches:
                break  # remaining micro-batches cannot fill an accumulation group
            lats = lats.to(device)
            # NOTE(review): unlabeled items carry label -1; confirm LiquidGen accepts that when num_classes > 0.
            lbls = lbls.to(device) if config.num_classes > 0 else None
            t = fm.sample_timesteps(lats.shape[0], device)
            noise = torch.randn_like(lats)
            xt = fm.add_noise(lats, noise, t)
            vtgt = fm.get_velocity_target(lats, noise)
            with autocast("cuda", enabled=amp_enabled):
                loss = F.mse_loss(model(xt, t, lbls), vtgt) / accum
            scaler.scale(loss).backward(); la += loss.item()
            if (bi + 1) % accum != 0:
                continue

            # --- optimizer step ---
            scaler.unscale_(opt)  # clip on true (unscaled) gradients
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            scaler.step(opt); scaler.update(); opt.zero_grad(); sched.step()
            ema.update(model); gs += 1

            if gs % config.log_every_n_steps == 0:
                al = la / config.log_every_n_steps
                elapsed = time.time() - t_start
                sps = gs / max(elapsed, 1)
                remaining = (total_steps - gs) / sps if sps > 0 else 0
                vram = torch.cuda.memory_allocated()/1024**3 if cuda else 0
                pct = gs / total_steps * 100
                print(f"step={gs:>6d}/{total_steps} ({pct:.0f}%) | ep={epoch} | "
                      f"loss={al:.4f} | gn={gn:.2f} | lr={opt.param_groups[0]['lr']:.2e} | "
                      f"vram={vram:.1f}G | {sps:.1f} st/s | ETA {_fmt_time(remaining)}")
                la = 0.0
                if math.isnan(al) or al > 50: print("Diverged!"); return

            if gs % config.sample_every_n_steps == 0:
                if vae is None:
                    from diffusers import AutoencoderKL
                    vae = AutoencoderKL.from_pretrained(config.vae_id, torch_dtype=torch.float16).to(device).eval()
                    for p in vae.parameters(): p.requires_grad_(False)
                ema.apply(model); model.eval()
                try:
                    sl = torch.randint(0, max(1, config.num_classes), (config.num_samples,), device=device) if config.num_classes > 0 else None
                    samp = fm.sample(model, (config.num_samples, config.latent_channels, lat_size, lat_size),
                                     device, config.num_sample_steps, sl, config.cfg_scale)
                    with torch.no_grad():
                        imgs = ((vae.decode(samp.half() / config.vae_scaling_factor).sample + 1) / 2).clamp(0, 1).float()
                    from torchvision.utils import save_image
                    save_image(imgs, f"{config.output_dir}/samples/step_{gs:07d}.png", nrow=2)
                    print(f"  Saved samples")
                finally:
                    # Always put the training weights back, even if sampling fails.
                    ema.restore(model); model.train()

            if gs % config.save_every_n_steps == 0:
                torch.save({"model": model.state_dict(), "ema": ema.shadow,
                            "optimizer": opt.state_dict(), "step": gs, "model_config": mcfg},
                           f"{config.output_dir}/checkpoints/step_{gs:07d}.pt")
        ep_time = time.time() - et
        ep_eta = ep_time * (config.num_epochs - epoch - 1)
        print(f"Epoch {epoch}/{config.num_epochs} done | {_fmt_time(ep_time)} | ETA {_fmt_time(ep_eta)}\n")

    final = f"{config.output_dir}/checkpoints/final.pt"
    torch.save({"model": model.state_dict(), "ema": ema.shadow, "model_config": mcfg, "step": gs}, final)
    total_time = time.time() - t_start
    print(f"\nDone! {gs} steps in {_fmt_time(total_time)} -> {final}")