| """ |
| Trains a modified Resnet to generate approximate dlatents using examples from a trained StyleGAN. |
| Props to @SimJeg on GitHub for the original code this is based on, from this thread: https://github.com/Puzer/stylegan-encoder/issues/1#issuecomment-490469454 |
| """ |
| import os |
| import math |
| import numpy as np |
| import pickle |
| import cv2 |
| import argparse |
|
|
| import dnnlib |
| import config |
| import dnnlib.tflib as tflib |
|
|
| import tensorflow |
| import keras |
| import keras.backend as K |
|
|
| from keras.applications.resnet50 import preprocess_input |
| from keras.layers import Input, LocallyConnected1D, Reshape, Permute, Conv2D, Add |
| from keras.models import Model, load_model |
|
|
def generate_dataset_main(n=10000, save_path=None, seed=None, model_res=1024, image_size=256, minibatch_size=16, truncation=0.7):
    """
    Generate one batch of StyleGAN images together with the dlatents that produced them.

    Every dlatent is emitted twice: once truncated towards the average dlatent
    by 'truncation' and once by '-truncation'. The batch therefore holds
    2 * ('n' // 2) samples, i.e. an odd 'n' is rounded down to an even count.

    For extra variation each output dlatent is assembled from several
    independently mapped latents instead of one latent repeated for every layer.

    Args:
        n: Requested number of samples.
        save_path: Not used by this function; accepted so that the signature
            matches generate_dataset().
        seed: Seed for the latent sampling, or None to use the global NumPy RNG.
        model_res: Resolution of the StyleGAN model; it determines the number of
            dlatent rows per sample (18 for 1024).
        image_size: Width and height the generated images are resized to.
        minibatch_size: Minibatch size passed to the mapping and synthesis networks.
        truncation: Truncation factor; its negation is used for the paired sample.

    Returns:
        Tuple (W, X). W holds the dlatents reshaped to
        (2 * ('n' // 2), model_scale, 512). X holds the matching images resized to
        'image_size' x 'image_size' and passed through the ResNet50 preprocess_input().
    """
    n = n // 2  # doubled again below by the +/- truncation pair
    model_scale = int(2 * (math.log(model_res, 2) - 1))  # e.g. 1024 -> 18

    Gs = load_Gs()

    # Number of mapped latents combined into one output dlatent: a small group
    # (3 when the layer count is divisible by 3, otherwise 2) or, decided by a
    # coin flip, half the number of layers.
    mod_l = 3 if model_scale % 3 == 0 else 2
    if seed is not None:
        use_half = bool(np.random.RandomState(seed).randint(2))
    else:
        use_half = bool(np.random.randint(2))
    if use_half:
        mod_l = model_scale // 2
    mod_r = model_scale // mod_l  # dlatent rows kept from each mapped latent

    # Sample the latents once, after mod_l is final. An earlier draw made before
    # the coin flip was always overwritten, so it has been removed; seeded
    # results are unchanged because the seeded draw uses a fresh RandomState.
    if seed is not None:
        Z = np.random.RandomState(seed).randn(n * mod_l, Gs.input_shape[1])
    else:
        Z = np.random.randn(n * mod_l, Gs.input_shape[1])

    # Run the mapping network on every latent
    # (output presumably shaped (n * mod_l, model_scale, 512) -- TODO confirm).
    W = Gs.components.mapping.run(Z, None, minibatch_size=minibatch_size)
    dlatent_avg = Gs.get_var('dlatent_avg')
    # Truncation trick with both signs: the new leading axis of length 2 holds
    # the +truncation and the -truncation variant of every dlatent.
    W = (W[np.newaxis] - dlatent_avg) * np.reshape([truncation, -truncation], [-1, 1, 1, 1]) + dlatent_avg
    W = np.append(W[0], W[1], axis=0)  # positive variants first, then the negative ones
    W = W[:, :mod_r]
    # Merge every run of mod_l consecutive latents (mod_r rows each) into one dlatent.
    W = W.reshape((n * 2, model_scale, 512))
    X = Gs.components.synthesis.run(W, randomize_noise=False, minibatch_size=minibatch_size, print_progress=True,
                                    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
    X = np.array([cv2.resize(x, (image_size, image_size), interpolation=cv2.INTER_AREA) for x in X])

    X = preprocess_input(X)
    return W, X
|
|
def generate_dataset(n=10000, save_path=None, seed=None, model_res=1024, image_size=256, minibatch_size=16, truncation=0.7):
    """
    Generate a dataset of about 'n' samples by calling generate_dataset_main() in batches.

    The request is divided into up to 16 batches to save memory, and the
    per-batch results are joined along the first axis.

    Args:
        n: Requested number of samples. Samples are produced in pairs, so the
            result holds 2 * ('n' // 2) of them.
        save_path: Optional directory. When given, W and X are also written there
            with np.save as 'W_<seed>_<n>' and 'X_<seed>_<n>'.
        seed: Seed for reproducible output, or None for unseeded sampling. Every
            batch receives its own seed derived from this value, so the batches
            differ from one another.
        model_res, image_size, minibatch_size, truncation: Forwarded unchanged to
            generate_dataset_main().

    Returns:
        Tuple (W, X) of the concatenated dlatents and images.
    """
    max_batches = 16
    # generate_dataset_main() halves its request and emits samples in pairs, so
    # each batch should have an even size of at least 2. Small requests use
    # fewer batches instead of asking for batches that would come back empty.
    batch_count = max(1, min(max_batches, n // 2))
    inc = (n // batch_count) // 2 * 2
    left = n - (batch_count - 1) * inc  # the final batch absorbs the remainder
    batch_sizes = [inc] * (batch_count - 1) + [left]

    if seed is None:
        batch_seeds = [None] * batch_count
    else:
        # Reusing 'seed' for every batch would make all batches identical, so
        # draw one reproducible seed per batch from a generator seeded with it.
        seed_rng = np.random.RandomState(seed)
        batch_seeds = [int(s) for s in seed_rng.randint(0, 2**31 - 1, size=batch_count)]

    W_parts = []
    X_parts = []
    for batch_n, batch_seed in zip(batch_sizes, batch_seeds):
        aW, aX = generate_dataset_main(batch_n, save_path, batch_seed, model_res, image_size, minibatch_size, truncation)
        W_parts.append(aW)
        X_parts.append(aX)

    # Join once at the end instead of re-copying the growing arrays per batch.
    W = np.concatenate(W_parts, axis=0)
    X = np.concatenate(X_parts, axis=0)
    del W_parts, X_parts

    if save_path is not None:
        prefix = '_{}_{}'.format(seed, n)
        np.save(os.path.join(save_path, 'W' + prefix), W)
        np.save(os.path.join(save_path, 'X' + prefix), X)

    return W, X
|
|
def is_square(n):
    """
    Return True if 'n' is a perfect square, otherwise False.

    Negative numbers are never perfect squares, so they yield False instead of
    the ValueError that math.sqrt() would raise for them.
    """
    if n < 0:
        return False
    root = int(math.sqrt(n) + 0.5)  # integer nearest to the real square root
    return root * root == n
| |
def get_resnet_model(save_path, model_res=1024, image_size=256, depth=2, size=0, activation='elu', loss='logcosh', optimizer='adam'):
    """
    Load the encoder model stored at 'save_path', or build and compile a new one.

    A new model consists of an ImageNet-pretrained ResNet backbone, an optional
    1x1 convolution that reduces the channel count, 'depth' blocks of paired
    LocallyConnected1D layers (with an additive skip connection from the second
    block onwards) and a final reshape to (model_scale, 512).

    Args:
        save_path: Model file; if it exists it is loaded and every other
            argument is ignored.
        model_res: Resolution of the StyleGAN model; it determines model_scale
            (18 for 1024).
        image_size: Width and height of the input images.
        depth: Number of locally connected blocks; a negative value is replaced by 1.
        size: Backbone choice: <= 0 ResNet50, 1 ResNet50V2, 2 ResNet101V2,
            >= 3 ResNet152V2.
        activation: Activation for the added convolution and locally connected layers.
        loss: Loss passed to model.compile().
        optimizer: Optimizer passed to model.compile().

    Returns:
        The loaded model, or the newly built and compiled one.
    """
    if os.path.exists(save_path):
        print('Loading model')
        return load_model(save_path)

    print('Building model')
    model_scale = int(2 * (math.log(model_res, 2) - 1))  # e.g. 1024 -> 18
    input_shape = (image_size, image_size, 3)

    # Pre-trained backbone without its classification head.
    if size <= 0:
        from keras.applications.resnet50 import ResNet50
        resnet = ResNet50(include_top=False, pooling=None, weights='imagenet', input_shape=input_shape)
    else:
        from keras_applications.resnet_v2 import ResNet50V2, ResNet101V2, ResNet152V2
        # keras_applications needs the Keras modules handed to it explicitly.
        v2_options = dict(include_top=False, pooling=None, weights='imagenet', input_shape=input_shape,
                          backend=keras.backend, layers=keras.layers, models=keras.models, utils=keras.utils)
        if size == 1:
            resnet = ResNet50V2(**v2_options)
        elif size == 2:
            resnet = ResNet101V2(**v2_options)
        elif size >= 3:
            resnet = ResNet152V2(**v2_options)

    # Factor the number of output values (model_scale * 512) into the two
    # dimensions used by the locally connected layers: a square when possible,
    # otherwise a power of two times the matching quotient.
    layer_size = model_scale * 8 * 8 * 8
    if is_square(layer_size):
        layer_l = layer_r = int(math.sqrt(layer_size) + 0.5)
    else:
        half_log = math.log(math.sqrt(layer_size), 2)
        layer_l = int(2 ** math.ceil(half_log))
        layer_r = int(layer_size // layer_l)

    inp = Input(shape=input_shape)
    x = resnet(inp)

    if depth < 0:
        depth = 1

    # Reduce the backbone output and lay it out as a 2-D grid.
    # NOTE(review): the Reshape targets appear to assume an 8x8 backbone feature
    # map (image_size=256) -- confirm before using other image sizes.
    if size <= 0:
        x = Conv2D(model_scale * 8, 1, activation=activation)(x)
        x = Reshape((layer_r, layer_l))(x)
    elif size <= 1:
        x = Conv2D(model_scale * 8 * 4, 1, activation=activation)(x)
        x = Reshape((layer_r * 2, layer_l * 2))(x)
    elif size == 2:
        x = Conv2D(1024, 1, activation=activation)(x)
        x = Reshape((256, 256))(x)
    else:
        x = Reshape((256, 512))(x)

    # Each block mixes along one axis, transposes, mixes along the other axis
    # and transposes back.
    previous_block = None
    blocks_left = depth
    while blocks_left > 0:
        x = LocallyConnected1D(layer_r, 1, activation=activation)(x)
        x = Permute((2, 1))(x)
        x = LocallyConnected1D(layer_l, 1, activation=activation)(x)
        x = Permute((2, 1))(x)
        if previous_block is not None:
            x = Add()([x, previous_block])  # skip connection from the previous block
        previous_block = x
        blocks_left -= 1

    x = Reshape((model_scale, 512))(x)
    model = Model(inputs=inp, outputs=x)
    model.compile(loss=loss, metrics=[], optimizer=optimizer)
    return model
|
|
def finetune_resnet(model, save_path, model_res=1024, image_size=256, batch_size=10000, test_size=1000, n_epochs=10, max_patience=5, seed=0, minibatch_size=32, truncation=0.7):
    """
    Finetune a ResNet to predict W from X.

    A test set of 'test_size' samples is generated once. Then, round after
    round, a fresh training set of 'batch_size' samples is generated, the model
    is fitted on it for 'n_epochs' epochs and evaluated on the test set. The
    loop ends once the test loss has failed to improve more than 'max_patience'
    times in a row; in that final round the model is also fitted on the test
    set itself. The model is saved to 'save_path' at the end of every round.

    Args:
        model: Compiled Keras model to train.
        save_path: File the model is saved to.
        model_res: Resolution of the StyleGAN model used to generate data.
        image_size: Width and height of the generated images; at least 224.
        batch_size: Number of samples in each generated training set.
        test_size: Number of samples in the test set.
        n_epochs: Epochs per call to model.fit().
        max_patience: Number of non-improving rounds tolerated before stopping.
        seed: Seed for reproducible data, or None for unseeded sampling.
        minibatch_size: Minibatch size for generation, training and evaluation.
        truncation: Truncation factor used when generating data.

    Raises:
        ValueError: If 'image_size' is smaller than 224.
    """
    if image_size < 224:
        raise ValueError('image_size must be at least 224, got {}'.format(image_size))

    print('Creating test set:')
    np.random.seed(seed)
    W_test, X_test = generate_dataset(n=test_size, model_res=model_res, image_size=image_size, seed=seed, minibatch_size=minibatch_size, truncation=truncation)

    # Training sets must not reuse the test-set seed: that would reproduce the
    # test data inside the training data and yield the same training set every
    # round. Draw a reproducible, distinct seed per round instead.
    train_seed_rng = None if seed is None else np.random.RandomState(seed)

    print('Generating training set:')
    patience = 0
    best_loss = np.inf

    while patience <= max_patience:
        # Release the previous round's arrays before generating new ones.
        W_train = X_train = None
        train_seed = None if train_seed_rng is None else int(train_seed_rng.randint(0, 2**31 - 1))
        W_train, X_train = generate_dataset(batch_size, model_res=model_res, image_size=image_size, seed=train_seed, minibatch_size=minibatch_size, truncation=truncation)
        model.fit(X_train, W_train, epochs=n_epochs, verbose=True, batch_size=minibatch_size)
        loss = model.evaluate(X_test, W_test, batch_size=minibatch_size)
        if loss < best_loss:
            print('New best test loss : {:.5f}'.format(loss))
            patience = 0
            best_loss = loss
        else:
            print('Test loss : {:.5f}'.format(loss))
            patience += 1
        if patience > max_patience:
            # The test set is finished with, so use it as training data too.
            print('Done with current test set.')
            model.fit(X_test, W_test, epochs=n_epochs, verbose=True, batch_size=minibatch_size)
        print('Saving model.')
        model.save(save_path)
|
|
def _str_to_bool(value):
    """
    Convert a command-line string to a bool.

    argparse's type=bool treats every non-empty string, including 'False', as
    True; this converter recognises the usual spellings of both values.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got {!r}'.format(value))


# Command-line options for training the ResNet encoder.
parser = argparse.ArgumentParser(description='Train a ResNet to predict latent representations of images in a StyleGAN model from generated examples', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--model_url', default='karras2019stylegan-ffhq-1024x1024.pkl', help='Fetch a StyleGAN model to train on from this URL')
parser.add_argument('--model_res', default=1024, help='The dimension of images in the StyleGAN model', type=int)
parser.add_argument('--data_dir', default='data', help='Directory for storing the ResNet model')
parser.add_argument('--model_path', default='data/finetuned_resnet.h5', help='Save / load / create the ResNet model with this file path')
parser.add_argument('--model_depth', default=1, help='Number of TreeConnect layers to add after ResNet', type=int)
parser.add_argument('--model_size', default=1, help='Model size - 0 - small, 1 - medium, 2 - large, 3 - full.', type=int)
parser.add_argument('--activation', default='elu', help='Activation function to use after ResNet')
parser.add_argument('--optimizer', default='adam', help='Optimizer to use')
parser.add_argument('--loss', default='logcosh', help='Loss function to use')
parser.add_argument('--use_fp16', default=False, help='Use 16-bit floating point', type=_str_to_bool)
parser.add_argument('--image_size', default=256, help='Size of images for ResNet model', type=int)
parser.add_argument('--batch_size', default=2048, help='Batch size for training the ResNet model', type=int)
parser.add_argument('--test_size', default=512, help='Batch size for testing the ResNet model', type=int)
parser.add_argument('--truncation', default=0.7, help='Generate images using truncation trick', type=float)
parser.add_argument('--max_patience', default=2, help='Number of iterations to wait while test loss does not improve', type=int)
parser.add_argument('--freeze_first', default=False, help='Start training with the pre-trained network frozen, then unfreeze', type=_str_to_bool)
parser.add_argument('--epochs', default=2, help='Number of training epochs to run for each batch', type=int)
parser.add_argument('--minibatch_size', default=16, help='Size of minibatches for training and generation', type=int)
parser.add_argument('--seed', default=-1, help='Pick a random seed for reproducibility (-1 for no random seed selected)', type=int)
parser.add_argument('--loop', default=-1, help='Run this many iterations (-1 for infinite, halt with CTRL-C)', type=int)
|
|
# Unrecognised command-line arguments are tolerated and collected in other_args.
args, other_args = parser.parse_known_args()

os.makedirs(args.data_dir, exist_ok=True)

# -1 on the command line means "no fixed seed".
if args.seed == -1:
    args.seed = None

if args.use_fp16:
    K.set_floatx('float16')
    K.set_epsilon(1e-4)

tflib.init_tf()

# Build the encoder for the same input size the training images are resized to
# (--image_size was previously not forwarded, so a new model was always built
# for 256). The size is ignored when an existing model file is loaded.
model = get_resnet_model(args.model_path, model_res=args.model_res, image_size=args.image_size, depth=args.model_depth, size=args.model_size, activation=args.activation, optimizer=args.optimizer, loss=args.loss)

# SECURITY NOTE: pickle.load() can execute arbitrary code contained in the file.
# Only point --model_url at StyleGAN pickles from a source you trust.
with dnnlib.util.open_url(args.model_url, cache_dir=config.cache_dir) as f:
    generator_network, discriminator_network, Gs_network = pickle.load(f)
|
|
def load_Gs():
    """Return the module-level Gs_network object unpickled from the StyleGAN model file."""
    network = Gs_network
    return network
|
|
# Keyword arguments shared by every finetune_resnet() call below.
finetune_options = dict(
    model_res=args.model_res,
    image_size=args.image_size,
    batch_size=args.batch_size,
    test_size=args.test_size,
    max_patience=args.max_patience,
    n_epochs=args.epochs,
    seed=args.seed,
    minibatch_size=args.minibatch_size,
    truncation=args.truncation,
)

if args.freeze_first:
    # Freeze model.layers[1] (presumably the pre-trained ResNet backbone -- confirm)
    # and recompile so the change takes effect.
    model.layers[1].trainable = False
    model.compile(loss=args.loss, metrics=[], optimizer=args.optimizer)

model.summary()

if args.freeze_first:
    # One finetuning pass with the layer frozen, then unfreeze it for the main loop.
    finetune_resnet(model, args.model_path, **finetune_options)
    model.layers[1].trainable = True
    model.compile(loss=args.loss, metrics=[], optimizer=args.optimizer)
    model.summary()

if args.loop < 0:
    # Negative --loop: train until interrupted.
    while True:
        finetune_resnet(model, args.model_path, **finetune_options)
else:
    for _ in range(args.loop):
        finetune_resnet(model, args.model_path, **finetune_options)
|
|