# LuminaRS / train.py
# asdf98's picture
# Upload train.py
# 02041e2 verified
"""Main training entry point for LuminaRS."""
import os, torch
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from datasets import load_dataset
from luminars.config import LuminaRSConfig
from luminars.model import LuminaRS
from luminars.flow_loss import flow_loss
from luminars.train import setup_stage1, setup_stage2, setup_stage3, freeze, count_trainable
# Human-readable art-style names. The list is indexed by each dataset item's
# integer `style` field when the training captions are built.
# NOTE(review): the order presumably mirrors the dataset's own class-label
# order for `style` — confirm against the dataset card before relying on it.
STYLE_LABELS = ["Abstract","Action Painting","Analytical Cubism","Art Nouveau","Baroque","Color Field Painting","Contemporary Realism","Cubism","Early Renaissance","Expressionism","Fauvism","High Renaissance","Impressionism","Mannerism","Minimalism","Naive Art","Neoclassicism","Northern Renaissance","Pointillism","Pop Art","Post-Impressionism","Realism","Rococo","Romanticism","Symbolism","Synthetic Cubism","Ukiyo-e"]
class WikiArtDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(transformed_image, caption)`` pairs.

    The class lives at module level (not inside ``get_dataloader``) so that
    instances can be pickled and handed to DataLoader worker processes on
    platforms that start workers with ``spawn``; locally defined classes
    cannot be pickled.

    Args:
        ds: Indexable, sized collection whose items are mappings with an
            ``"image"`` entry (an object exposing ``convert("RGB")``) and an
            integer ``"style"`` entry.
        tfm: Callable applied to the RGB-converted image.
        labels: Sequence of style names indexed by the ``"style"`` value.
    """

    def __init__(self, ds, tfm, labels):
        self.ds = ds
        self.tfm = tfm
        self.labels = labels

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        item = self.ds[i]
        img = self.tfm(item["image"].convert("RGB"))
        cap = f"a painting in {self.labels[item['style']]} style"
        return img, cap


def get_dataloader(batch_size=4, num_workers=2, image_size=512):
    """Build the shuffled training DataLoader over ``huggan/wikiart``.

    Args:
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        image_size: Size passed to both ``Resize`` and ``CenterCrop``
            (defaults to 512).

    Returns:
        A ``DataLoader`` whose batches are ``(images, captions)``. It shuffles,
        uses pinned memory, and drops the final incomplete batch.
    """
    ds = load_dataset("huggan/wikiart", split="train")
    tfm = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # Per-channel mean 0.5 / std 0.5 normalisation.
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    return DataLoader(
        WikiArtDataset(ds, tfm, STYLE_LABELS),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
def train(stage=1, epochs=10, lr=1e-4, batch_size=4, save_dir="./checkpoints"):
    """Train LuminaRS for one stage, saving a checkpoint after every epoch.

    Images are encoded to latents with a frozen VAE and captions to hidden
    states with a frozen CLIP text encoder; the model is then optimised with
    ``flow_loss`` on those two tensors.

    Args:
        stage: Training stage, one of 1, 2 or 3. Selects the matching
            ``setup_stageN`` call and scales ``lr`` by 1, 0.1 or 0.01.
        epochs: Number of passes over the dataloader.
        lr: Base learning rate before the per-stage scaling.
        batch_size: Batch size forwarded to ``get_dataloader``.
        save_dir: Directory for checkpoints; created if missing. Files are
            named ``s{stage}_e{epoch}.pt``.

    Returns:
        The trained model.

    Raises:
        ValueError: If ``stage`` is not 1, 2 or 3.
    """
    # Validate up front so a bad stage fails before any weights are loaded,
    # instead of silently training without any stage setup.
    lr_scale = {1: 1.0, 2: 0.1, 3: 0.01}
    if stage not in lr_scale:
        raise ValueError(f"stage must be 1, 2 or 3, got {stage!r}")

    os.makedirs(save_dir, exist_ok=True)
    cfg = LuminaRSConfig()
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision and loss scaling only make sense on CUDA; on CPU both
    # are switched off explicitly rather than left to warn and self-disable.
    use_amp = dev.type == "cuda"

    vae = AutoencoderKL.from_pretrained(cfg.vae_pretrained).to(dev)
    freeze(vae)
    tok = CLIPTokenizer.from_pretrained(cfg.clip_pretrained)
    clip = CLIPTextModel.from_pretrained(cfg.clip_pretrained).to(dev)
    freeze(clip)

    model = LuminaRS(cfg).to(dev)
    setup_fn = {1: setup_stage1, 2: setup_stage2, 3: setup_stage3}[stage]
    setup_fn(model)
    lr *= lr_scale[stage]
    print(f"Stage {stage}: {count_trainable(model)/1e6:.1f}M train, lr={lr:.1e}")

    # Materialise the trainable parameters once so the optimizer and the
    # gradient clipping below operate on exactly the same set.
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=lr, weight_decay=0.01)
    dl = get_dataloader(batch_size)
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    model.train()

    for ep in range(epochs):
        for step, (imgs, caps) in enumerate(dl):
            imgs = imgs.to(dev)
            tks = tok(
                caps,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt",
            ).input_ids.to(dev)

            # The conditioning networks are frozen: no graph is needed here.
            with torch.no_grad():
                z = vae.encode(imgs).latent_dist.sample() * vae.config.scaling_factor
                te = clip(tks).last_hidden_state

            with torch.amp.autocast(dev.type, enabled=use_amp):
                loss = flow_loss(model, z, te)

            scaler.scale(loss).backward()
            # Unscale before clipping so the 1.0 threshold applies to the
            # true gradient norm, not the loss-scaled one.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()

            if step % 50 == 0:
                print(f"[S{stage} E{ep} {step}] loss={loss.item():.4f}")

        # One checkpoint per epoch.
        torch.save(model.state_dict(), os.path.join(save_dir, f"s{stage}_e{ep}.pt"))

    return model
if __name__ == "__main__":
    # Script entry point: run stage 1 for five epochs at a base lr of 1e-4.
    train(
        stage=1,
        epochs=5,
        lr=1e-4,
    )