| """Main training entry point for LuminaRS.""" |
| import os, torch |
| from torch.utils.data import DataLoader |
| from torchvision import transforms |
| from transformers import CLIPTextModel, CLIPTokenizer |
| from diffusers import AutoencoderKL |
| from datasets import load_dataset |
| from luminars.config import LuminaRSConfig |
| from luminars.model import LuminaRS |
| from luminars.flow_loss import flow_loss |
| from luminars.train import setup_stage1, setup_stage2, setup_stage3, freeze, count_trainable |
|
|
# Human-readable art-style names used to build training captions.
# The list is indexed by the integer ``style`` field of each dataset record,
# so the order of entries matters.
# NOTE(review): presumably index-aligned with the dataset's own style class
# ordering -- confirm against the dataset's label names before relying on it.
STYLE_LABELS = [
    "Abstract",
    "Action Painting",
    "Analytical Cubism",
    "Art Nouveau",
    "Baroque",
    "Color Field Painting",
    "Contemporary Realism",
    "Cubism",
    "Early Renaissance",
    "Expressionism",
    "Fauvism",
    "High Renaissance",
    "Impressionism",
    "Mannerism",
    "Minimalism",
    "Naive Art",
    "Neoclassicism",
    "Northern Renaissance",
    "Pointillism",
    "Pop Art",
    "Post-Impressionism",
    "Realism",
    "Rococo",
    "Romanticism",
    "Symbolism",
    "Synthetic Cubism",
    "Ukiyo-e",
]
|
|
class _ArtDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image_tensor, caption)`` pairs.

    Defined at module level (not inside ``get_dataloader``) so instances can
    be pickled when DataLoader worker processes are started with the
    ``spawn`` method; classes local to a function cannot be pickled.
    """

    def __init__(self, ds, tfm):
        self.ds = ds    # underlying indexable dataset of records
        self.tfm = tfm  # callable applied to each RGB image

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        item = self.ds[i]
        # Force three channels so the 3-channel Normalize below always applies.
        img = self.tfm(item["image"].convert("RGB"))
        cap = f"a painting in {STYLE_LABELS[item['style']]} style"
        return img, cap


def get_dataloader(batch_size=4, num_workers=2, image_size=512):
    """Build the shuffled training DataLoader over the WikiArt train split.

    Each image is resized, center-cropped to ``image_size`` x ``image_size``,
    converted to a tensor and normalized with mean 0.5 / std 0.5 per channel.
    Each caption is ``"a painting in <style> style"``, where the style name is
    looked up in ``STYLE_LABELS`` by the record's ``style`` field.

    Args:
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        image_size: Side length in pixels of the square crop (default 512).

    Returns:
        A ``DataLoader`` yielding ``(images, captions)`` batches; the final
        incomplete batch is dropped.
    """
    ds = load_dataset("huggan/wikiart", split="train")
    tfm = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ]
    )
    return DataLoader(
        _ArtDataset(ds, tfm),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # Pinning only helps host-to-GPU copies; requesting it on a CPU-only
        # host just produces a warning.
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
    )
|
|
| def train(stage=1, epochs=10, lr=1e-4, batch_size=4, save_dir="./checkpoints"): |
| os.makedirs(save_dir, exist_ok=True); cfg = LuminaRSConfig() |
| dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| vae = AutoencoderKL.from_pretrained(cfg.vae_pretrained).to(dev); freeze(vae) |
| tok = CLIPTokenizer.from_pretrained(cfg.clip_pretrained) |
| clip = CLIPTextModel.from_pretrained(cfg.clip_pretrained).to(dev); freeze(clip) |
| model = LuminaRS(cfg).to(dev) |
| if stage==1: setup_stage1(model) |
| elif stage==2: setup_stage2(model); lr*=0.1 |
| elif stage==3: setup_stage3(model); lr*=0.01 |
| print(f"Stage {stage}: {count_trainable(model)/1e6:.1f}M train, lr={lr:.1e}") |
| opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=0.01) |
| dl = get_dataloader(batch_size); scaler = torch.amp.GradScaler("cuda"); model.train() |
| for ep in range(epochs): |
| for step, (imgs, caps) in enumerate(dl): |
| imgs=imgs.to(dev); tks=tok(caps,padding="max_length",max_length=77,truncation=True,return_tensors="pt").input_ids.to(dev) |
| with torch.no_grad(): z=vae.encode(imgs).latent_dist.sample()*vae.config.scaling_factor; te=clip(tks).last_hidden_state |
| with torch.amp.autocast("cuda"): loss=flow_loss(model,z,te) |
| scaler.scale(loss).backward(); scaler.unscale_(opt) |
| torch.nn.utils.clip_grad_norm_(model.parameters(),1.0) |
| scaler.step(opt); scaler.update(); opt.zero_grad() |
| if step%50==0: print(f"[S{stage} E{ep} {step}] loss={loss.item():.4f}") |
| torch.save(model.state_dict(), f"{save_dir}/s{stage}_e{ep}.pt") |
| return model |
|
|
if __name__ == "__main__":
    # Script entry point: five epochs of stage 1 at a base learning rate of 1e-4.
    train(stage=1, epochs=5, lr=1e-4)
|
|