| import os |
| from argparse import ArgumentParser |
|
|
| from omegaconf import OmegaConf |
| import torch |
| from torch.utils.data import DataLoader |
| from torchvision.utils import make_grid |
| from accelerate import Accelerator |
| from accelerate.utils import set_seed |
| from einops import rearrange |
| from tqdm import tqdm |
| from torch.utils.tensorboard import SummaryWriter |
| from PIL import Image, ImageDraw, ImageFont |
| import numpy as np |
|
|
| from model import ControlLDM, SwinIR, Diffusion |
| from utils.common import instantiate_from_config |
| from utils.sampler import SpacedSampler |
|
|
|
|
def log_txt_as_img(wh, xc):
    """Render a batch of captions as images for TensorBoard logging.

    Args:
        wh: Canvas size ``(width, height)`` in pixels, as accepted by
            ``PIL.Image.new``.
        xc: Sequence of caption strings, one image per entry.

    Returns:
        ``torch.Tensor`` of shape ``(len(xc), 3, height, width)`` and dtype
        float64, with values in ``[-1, 1]`` (the white background maps to
        1.0). An empty ``xc`` yields an empty ``(0, 3, height, width)`` tensor.
    """
    width, height = wh
    if len(xc) == 0:
        # np.stack raises on an empty list; return an empty batch with the
        # same layout/dtype as the non-empty path instead.
        return torch.empty((0, 3, height, width), dtype=torch.float64)

    # The font is the same for every caption: load it once, not per item.
    font = ImageFont.load_default()
    # Hard-wrap at ~40 characters per 256 px of canvas width. Clamp to >= 1 so
    # the range() step below can never be zero for very narrow canvases.
    chars_per_line = max(1, int(40 * (width / 256)))

    rendered = []
    for caption in xc:
        canvas = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(canvas)
        lines = "\n".join(
            caption[start:start + chars_per_line]
            for start in range(0, len(caption), chars_per_line)
        )
        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            # Keep this caption's canvas as drawn so far and move on.
            print("Cant encode string for logging. Skipping.")
        # HWC uint8 -> CHW, rescaled from [0, 255] to [-1, 1].
        rendered.append(np.array(canvas).transpose(2, 0, 1) / 127.5 - 1.0)
    return torch.tensor(np.stack(rendered))
|
|
|
|
def main(args) -> None:
    """Train the ControlNet branch of a ControlLDM on SwinIR-restored conditions.

    Loads an OmegaConf config from ``args.config``. The keys read are
    ``cfg.train.*`` (exp_dir, sd_path, resume, swinir_path, learning_rate,
    batch_size, num_workers, train_steps, log_every, ckpt_every,
    image_every), ``cfg.model.{cldm,swinir,diffusion}`` and
    ``cfg.dataset.train``.

    Only ``cldm.controlnet`` parameters are handed to the optimizer. SwinIR
    is frozen, and VAE encoding plus condition preparation run under
    ``torch.no_grad``. Scalars, sample images and ControlNet checkpoints are
    written by the local main process only.

    Args:
        args: Parsed CLI namespace providing a ``config`` path.

    Raises:
        ValueError: If the training DataLoader yields no batches per epoch
            (e.g. the dataset is smaller than ``batch_size`` with
            ``drop_last=True``). Without this check the step loop never
            advances and training would spin forever.
    """
    # split_batches=True: each batch from the prepared loader is split across
    # processes, so cfg.train.batch_size acts as the global batch size.
    accelerator = Accelerator(split_batches=True)
    set_seed(231)
    device = accelerator.device
    cfg = OmegaConf.load(args.config)

    # Experiment folder. exp_dir/ckpt_dir are only bound on the local main
    # process; every later use of them is guarded accordingly.
    if accelerator.is_local_main_process:
        exp_dir = cfg.train.exp_dir
        os.makedirs(exp_dir, exist_ok=True)
        ckpt_dir = os.path.join(exp_dir, "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)
        print(f"Experiment directory created at {exp_dir}")

    # ControlLDM: pretrained SD weights first, then ControlNet weights either
    # from a resume checkpoint or initialized from the SD UNet.
    cldm: ControlLDM = instantiate_from_config(cfg.model.cldm)
    sd = torch.load(cfg.train.sd_path, map_location="cpu")["state_dict"]
    unused = cldm.load_pretrained_sd(sd)
    if accelerator.is_local_main_process:
        print(f"strictly load pretrained SD weight from {cfg.train.sd_path}\n"
              f"unused weights: {unused}")

    if cfg.train.resume:
        cldm.load_controlnet_from_ckpt(torch.load(cfg.train.resume, map_location="cpu"))
        if accelerator.is_local_main_process:
            print(f"strictly load controlnet weight from checkpoint: {cfg.train.resume}")
    else:
        init_with_new_zero, init_with_scratch = cldm.load_controlnet_from_unet()
        if accelerator.is_local_main_process:
            print(f"strictly load controlnet weight from pretrained SD\n"
                  f"weights initialized with newly added zeros: {init_with_new_zero}\n"
                  f"weights initialized from scratch: {init_with_scratch}")

    # Frozen SwinIR: turns lq into the "clean" image used for conditioning.
    # A leading "module." (the prefix a DataParallel/DDP wrapper adds to
    # parameter names) is stripped before the strict load.
    swinir: SwinIR = instantiate_from_config(cfg.model.swinir)
    sd = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in torch.load(cfg.train.swinir_path, map_location="cpu").items()
    }
    swinir.load_state_dict(sd, strict=True)
    for p in swinir.parameters():
        p.requires_grad = False
    if accelerator.is_local_main_process:
        print(f"load SwinIR from {cfg.train.swinir_path}")

    diffusion: Diffusion = instantiate_from_config(cfg.model.diffusion)

    # Only the ControlNet branch is optimized.
    opt = torch.optim.AdamW(cldm.controlnet.parameters(), lr=cfg.train.learning_rate)

    # Data.
    dataset = instantiate_from_config(cfg.dataset.train)
    loader = DataLoader(
        dataset=dataset, batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        shuffle=True, drop_last=True
    )
    if accelerator.is_local_main_process:
        print(f"Dataset contains {len(dataset):,} images from {dataset.file_list}")
    # With drop_last=True, a dataset smaller than one batch yields zero
    # batches. global_step would then never advance and the while-loop below
    # would never terminate, so fail fast.
    if len(loader) == 0:
        raise ValueError(
            f"Training DataLoader yields no batches: dataset has {len(dataset)} "
            f"samples but batch_size={cfg.train.batch_size} with drop_last=True"
        )

    # Move to device and let accelerate wrap for (possibly) distributed
    # training. pure_cldm is the unwrapped module, used for helper methods
    # and for saving checkpoints.
    cldm.train().to(device)
    swinir.eval().to(device)
    diffusion.to(device)
    cldm, opt, loader = accelerator.prepare(cldm, opt, loader)
    pure_cldm: ControlLDM = accelerator.unwrap_model(cldm)

    # Bookkeeping.
    global_step = 0
    max_steps = cfg.train.train_steps
    step_loss = []   # per-step losses since the last scalar log
    epoch = 0
    epoch_loss = []  # per-step losses of the current epoch
    sampler = SpacedSampler(diffusion.betas)
    # The writer exists only on the local main process; None elsewhere.
    writer = SummaryWriter(exp_dir) if accelerator.is_local_main_process else None
    if accelerator.is_local_main_process:
        print(f"Training for {max_steps} steps...")

    try:
        while global_step < max_steps:
            pbar = tqdm(iterable=None, disable=not accelerator.is_local_main_process, unit="batch", total=len(loader))
            for gt, lq, prompt in loader:
                # Batches arrive as (b, h, w, c); the models take (b, c, h, w).
                gt = rearrange(gt, "b h w c -> b c h w").contiguous().float().to(device)
                lq = rearrange(lq, "b h w c -> b c h w").contiguous().float().to(device)
                # Target latents and conditioning are computed without autograd.
                with torch.no_grad():
                    z_0 = pure_cldm.vae_encode(gt)
                    clean = swinir(lq)
                    cond = pure_cldm.prepare_condition(clean, prompt)
                t = torch.randint(0, diffusion.num_timesteps, (z_0.shape[0],), device=device)

                # The loss goes through the accelerate-prepared (possibly
                # DDP-wrapped) model so gradient synchronization applies.
                loss = diffusion.p_losses(cldm, z_0, t, cond)
                opt.zero_grad()
                accelerator.backward(loss)
                opt.step()

                accelerator.wait_for_everyone()

                global_step += 1
                # One device->host sync per step instead of three .item() calls.
                loss_value = loss.item()
                step_loss.append(loss_value)
                epoch_loss.append(loss_value)
                pbar.update(1)
                pbar.set_description(f"Epoch: {epoch:04d}, Global Step: {global_step:07d}, Loss: {loss_value:.6f}")

                # Scalar log: mean loss since the last log, gathered over all
                # processes. Every process must take part in the gather.
                if global_step % cfg.train.log_every == 0:
                    avg_loss = accelerator.gather(torch.tensor(step_loss, device=device).unsqueeze(0)).mean().item()
                    step_loss.clear()
                    if writer is not None:
                        writer.add_scalar("loss/loss_simple_step", avg_loss, global_step)

                # ControlNet-only checkpoint.
                if global_step % cfg.train.ckpt_every == 0:
                    if accelerator.is_local_main_process:
                        checkpoint = pure_cldm.controlnet.state_dict()
                        ckpt_path = f"{ckpt_dir}/{global_step:07d}.pt"
                        torch.save(checkpoint, ckpt_path)

                # Image log, also on the very first step as an early sanity check.
                if global_step % cfg.train.image_every == 0 or global_step == 1:
                    N = 12
                    log_clean = clean[:N]
                    log_cond = {k: v[:N] for k, v in cond.items()}
                    log_gt, log_lq = gt[:N], lq[:N]
                    log_prompt = prompt[:N]
                    cldm.eval()
                    with torch.no_grad():
                        # Every process runs the sampler on the prepared model;
                        # only the local main process decodes and writes images.
                        z = sampler.sample(
                            model=cldm, device=device, steps=50, batch_size=len(log_gt), x_size=z_0.shape[1:],
                            cond=log_cond, uncond=None, cfg_scale=1.0, x_T=None,
                            progress=accelerator.is_local_main_process, progress_leave=False
                        )
                        if writer is not None:
                            for tag, image in [
                                ("image/samples", (pure_cldm.vae_decode(z) + 1) / 2),
                                ("image/gt", (log_gt + 1) / 2),
                                ("image/lq", log_lq),
                                ("image/condition", log_clean),
                                ("image/condition_decoded", (pure_cldm.vae_decode(log_cond["c_img"]) + 1) / 2),
                                ("image/prompt", (log_txt_as_img((512, 512), log_prompt) + 1) / 2)
                            ]:
                                writer.add_image(tag, make_grid(image, nrow=4), global_step)
                    cldm.train()
                accelerator.wait_for_everyone()
                if global_step == max_steps:
                    break

            pbar.close()
            epoch += 1
            # Epoch-mean loss over all processes (collective: all ranks call it).
            avg_epoch_loss = accelerator.gather(torch.tensor(epoch_loss, device=device).unsqueeze(0)).mean().item()
            epoch_loss.clear()
            if writer is not None:
                writer.add_scalar("loss/loss_simple_epoch", avg_epoch_loss, global_step)
    finally:
        # Close (and thereby flush) the TensorBoard writer even if training
        # raises, so the most recent events are not lost.
        if writer is not None:
            writer.close()

    if accelerator.is_local_main_process:
        print("done!")
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: the single required option is the path to the
    # OmegaConf config file consumed by main().
    cli = ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    main(cli.parse_args())
|
|