| import argparse |
| import math |
| import os |
| import pickle |
|
|
| import torch |
| import torchvision |
| from torch import optim |
| from tqdm import tqdm |
|
|
| from StyleCLIP.criteria.clip_loss import CLIPLoss |
| from StyleCLIP.models.stylegan2.model import Generator |
| import clip |
| from StyleCLIP.utils import ensure_checkpoint_exists |
|
|
|
|
| def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): |
| lr_ramp = min(1, (1 - t) / rampdown) |
| lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) |
| lr_ramp = lr_ramp * min(1, t / rampup) |
|
|
| return initial_lr * lr_ramp |
|
|
|
|
| def main(args, use_old_G): |
| ensure_checkpoint_exists(args.ckpt) |
| text_inputs = torch.cat([clip.tokenize(args.description)]).cuda() |
| os.makedirs(args.results_dir, exist_ok=True) |
| new_generator_path = f'/disk2/danielroich/Sandbox/stylegan2_ada_pytorch/checkpoints/model_{args.run_id}_{args.image_name}.pt' |
| old_generator_path = '/disk2/danielroich/Sandbox/pretrained_models/ffhq.pkl' |
|
|
| if not use_old_G: |
| with open(new_generator_path, 'rb') as f: |
| G = torch.load(f).cuda().eval() |
| else: |
| with open(old_generator_path, 'rb') as f: |
| G = pickle.load(f)['G_ema'].cuda().eval() |
|
|
| if args.latent_path: |
| latent_code_init = torch.load(args.latent_path).cuda() |
| elif args.mode == "edit": |
| latent_code_init_not_trunc = torch.randn(1, 512).cuda() |
| with torch.no_grad(): |
| latent_code_init = G.mapping(latent_code_init_not_trunc, None) |
|
|
| latent = latent_code_init.detach().clone() |
| latent.requires_grad = True |
|
|
| clip_loss = CLIPLoss(args) |
|
|
| optimizer = optim.Adam([latent], lr=args.lr) |
|
|
| pbar = tqdm(range(args.step)) |
|
|
| for i in pbar: |
| t = i / args.step |
| lr = get_lr(t, args.lr) |
| optimizer.param_groups[0]["lr"] = lr |
|
|
| img_gen = G.synthesis(latent, noise_mode='const') |
|
|
| c_loss = clip_loss(img_gen, text_inputs) |
|
|
| if args.mode == "edit": |
| l2_loss = ((latent_code_init - latent) ** 2).sum() |
| loss = c_loss + args.l2_lambda * l2_loss |
| else: |
| loss = c_loss |
|
|
| optimizer.zero_grad() |
| loss.backward() |
| optimizer.step() |
|
|
| pbar.set_description( |
| ( |
| f"loss: {loss.item():.4f};" |
| ) |
| ) |
| if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0: |
| with torch.no_grad(): |
| img_gen = G.synthesis(latent, noise_mode='const') |
|
|
| torchvision.utils.save_image(img_gen, |
| f"/disk2/danielroich/Sandbox/StyleCLIP/results/inference_results/{str(i).zfill(5)}.png", |
| normalize=True, range=(-1, 1)) |
|
|
| if args.mode == "edit": |
| with torch.no_grad(): |
| img_orig = G.synthesis(latent_code_init, noise_mode='const') |
|
|
| final_result = torch.cat([img_orig, img_gen]) |
| else: |
| final_result = img_gen |
|
|
| return final_result |
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--description", type=str, default="a person with purple hair", |
| help="the text that guides the editing/generation") |
| parser.add_argument("--ckpt", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt", |
| help="pretrained StyleGAN2 weights") |
| parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution") |
| parser.add_argument("--lr_rampup", type=float, default=0.05) |
| parser.add_argument("--lr", type=float, default=0.1) |
| parser.add_argument("--step", type=int, default=300, help="number of optimization steps") |
| parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"], |
| help="choose between edit an image an generate a free one") |
| parser.add_argument("--l2_lambda", type=float, default=0.008, |
| help="weight of the latent distance (used for editing only)") |
| parser.add_argument("--latent_path", type=str, default=None, |
| help="starts the optimization from the given latent code if provided. Otherwose, starts from" |
| "the mean latent in a free generation, and from a random one in editing. " |
| "Expects a .pt format") |
| parser.add_argument("--truncation", type=float, default=0.7, |
| help="used only for the initial latent vector, and only when a latent code path is" |
| "not provided") |
| parser.add_argument("--save_intermediate_image_every", type=int, default=20, |
| help="if > 0 then saves intermidate results during the optimization") |
| parser.add_argument("--results_dir", type=str, default="results") |
|
|
| args = parser.parse_args() |
|
|
| result_image = main(args) |
|
|
| torchvision.utils.save_image(result_image.detach().cpu(), os.path.join(args.results_dir, "final_result.jpg"), |
| normalize=True, scale_each=True, range=(-1, 1)) |
|
|