Upload train.py
Browse files
train.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Main training entry point for LuminaRS."""
|
| 2 |
+
import os, torch
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
| 6 |
+
from diffusers import AutoencoderKL
|
| 7 |
+
from datasets import load_dataset
|
| 8 |
+
from luminars.config import LuminaRSConfig
|
| 9 |
+
from luminars.model import LuminaRS
|
| 10 |
+
from luminars.flow_loss import flow_loss
|
| 11 |
+
from luminars.train import setup_stage1, setup_stage2, setup_stage3, freeze, count_trainable
|
| 12 |
+
|
| 13 |
+
# Human-readable art-style names used to build training captions; the integer
# ``style`` field of a dataset item is used as an index into this list.
# NOTE(review): the order is assumed to line up with the class ids of the
# dataset's ``style`` feature -- confirm against the dataset's label names.
STYLE_LABELS = [
    "Abstract",
    "Action Painting",
    "Analytical Cubism",
    "Art Nouveau",
    "Baroque",
    "Color Field Painting",
    "Contemporary Realism",
    "Cubism",
    "Early Renaissance",
    "Expressionism",
    "Fauvism",
    "High Renaissance",
    "Impressionism",
    "Mannerism",
    "Minimalism",
    "Naive Art",
    "Neoclassicism",
    "Northern Renaissance",
    "Pointillism",
    "Pop Art",
    "Post-Impressionism",
    "Realism",
    "Rococo",
    "Romanticism",
    "Symbolism",
    "Synthetic Cubism",
    "Ukiyo-e",
]
|
| 14 |
+
|
| 15 |
+
class ArtDS(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image_tensor, caption)`` pairs.

    Lives at module level (not inside ``get_dataloader``) so instances can be
    pickled; a function-local class breaks DataLoader workers that are
    started with the ``spawn`` / ``forkserver`` methods.
    """

    def __init__(self, ds, tfm):
        self.ds = ds      # indexable collection of items with "image" and "style" keys
        self.tfm = tfm    # callable applied to the RGB-converted image

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        item = self.ds[i]
        img = self.tfm(item["image"].convert("RGB"))
        cap = f"a painting in {STYLE_LABELS[item['style']]} style"
        return img, cap


def get_dataloader(batch_size=4, num_workers=2):
    """Build the shuffled training DataLoader over the ``huggan/wikiart`` train split.

    Each image is resized and center-cropped to 512x512, converted to a
    tensor and normalized with mean 0.5 / std 0.5 per channel. Each caption is
    the string ``"a painting in <style> style"``.

    Args:
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``DataLoader`` yielding ``(images, captions)`` batches; the last
        incomplete batch is dropped.
    """
    ds = load_dataset("huggan/wikiart", split="train")
    tfm = transforms.Compose(
        [
            transforms.Resize(512),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ]
    )
    return DataLoader(
        ArtDS(ds, tfm),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # Pinned host memory only helps (and only avoids a warning) with CUDA.
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
    )
|
| 25 |
+
|
| 26 |
+
def train(stage=1, epochs=10, lr=1e-4, batch_size=4, save_dir="./checkpoints"):
    """Run one LuminaRS training stage and save a checkpoint after every epoch.

    The VAE and CLIP text encoder are loaded from the checkpoints named in the
    config, passed to ``freeze`` and only evaluated under ``torch.no_grad()``.
    Images are encoded to scaled VAE latents, captions to CLIP hidden states,
    and the model is optimized on ``flow_loss`` with AdamW and gradient-norm
    clipping at 1.0.

    Args:
        stage: Training stage, one of 1, 2 or 3. Selects the matching
            ``setup_stageN`` call and scales ``lr`` by 1, 0.1 or 0.01.
        epochs: Number of passes over the dataloader.
        lr: Base learning rate, before the stage-specific scaling.
        batch_size: Batch size passed to ``get_dataloader``.
        save_dir: Directory for ``s{stage}_e{epoch}.pt`` checkpoints; created
            if it does not exist.

    Returns:
        The trained model.

    Raises:
        ValueError: If ``stage`` is not 1, 2 or 3.
    """
    # Fail fast: an unknown stage would otherwise train with no stage setup
    # and an unscaled learning rate.
    if stage not in (1, 2, 3):
        raise ValueError(f"stage must be 1, 2 or 3, got {stage!r}")

    os.makedirs(save_dir, exist_ok=True)
    cfg = LuminaRSConfig()
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision and loss scaling are only used when actually on CUDA,
    # so the CPU fallback runs in plain fp32.
    use_amp = dev.type == "cuda"

    vae = AutoencoderKL.from_pretrained(cfg.vae_pretrained).to(dev)
    freeze(vae)
    tok = CLIPTokenizer.from_pretrained(cfg.clip_pretrained)
    clip = CLIPTextModel.from_pretrained(cfg.clip_pretrained).to(dev)
    freeze(clip)

    model = LuminaRS(cfg).to(dev)
    if stage == 1:
        setup_stage1(model)
    elif stage == 2:
        setup_stage2(model)
        lr *= 0.1
    else:
        setup_stage3(model)
        lr *= 0.01
    print(f"Stage {stage}: {count_trainable(model)/1e6:.1f}M train, lr={lr:.1e}")

    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=lr, weight_decay=0.01)
    dl = get_dataloader(batch_size)
    scaler = torch.amp.GradScaler(dev.type, enabled=use_amp)
    model.train()

    for ep in range(epochs):
        for step, (imgs, caps) in enumerate(dl):
            imgs = imgs.to(dev)
            tks = tok(
                caps,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt",
            ).input_ids.to(dev)

            # Conditioning inputs come from the auxiliary models; no gradients needed.
            with torch.no_grad():
                z = vae.encode(imgs).latent_dist.sample() * vae.config.scaling_factor
                te = clip(tks).last_hidden_state

            with torch.amp.autocast(dev.type, enabled=use_amp):
                loss = flow_loss(model, z, te)

            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradient norms.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()

            if step % 50 == 0:
                print(f"[S{stage} E{ep} {step}] loss={loss.item():.4f}")

        torch.save(model.state_dict(), os.path.join(save_dir, f"s{stage}_e{ep}.pt"))

    return model
|
| 50 |
+
|
| 51 |
+
# Script entry point: run stage-1 training for 5 epochs at a 1e-4 base learning rate.
if __name__ == "__main__":
    train(stage=1, epochs=5, lr=1e-4)
|