|
|
| from utils import download_url |
| import argparse |
| import numpy as np |
| import PIL.Image |
| import dnnlib |
| import dnnlib.tflib as tflib |
| import re |
| import sys |
| from io import BytesIO |
| import IPython.display |
| from math import ceil |
| from PIL import Image, ImageDraw |
| import os |
| import pickle |
| from utils import log_progress, imshow, create_image_grid, show_animation |
| import imageio |
| import glob |
| import gdown |
|
|
class Rasm:
    """Load a pretrained generator network and sample images from it.

    Creating an instance downloads a pickled ``(G, D, Gs)`` network triple
    from Google Drive and initialises TensorFlow through ``dnnlib.tflib``.
    The remaining methods turn seeds or latent vectors into ``PIL.Image``
    objects, an image grid, or an interpolation video.
    """

    def __init__(self, mode='calligraphy'):
        """Download the checkpoint selected by *mode* and load the networks.

        Args:
            mode: ``'calligraphy'`` selects the first checkpoint; any other
                value selects the second one.
        """
        if mode == 'calligraphy':
            url = 'https://drive.google.com/uc?id=138fdURGxdkOwZq7IWvnrGLcfo5VI8O1R'
        else:
            url = 'https://drive.google.com/uc?id=13h-alXGI0hbNOJy1qbmeoroXZSPBHEG2'

        output = 'model.pkl'
        print('Downloading networks from "%s"...' % url)
        gdown.download(url, output, quiet=False)
        dnnlib.tflib.init_tf()
        # SECURITY NOTE: unpickling runs arbitrary code embedded in the file.
        # Only point the URLs above at checkpoints you trust.
        with dnnlib.util.open_url(output) as fp:
            self._G, self._D, self.Gs = pickle.load(fp)
        # Noise inputs of the synthesis network; they are overwritten with a
        # fixed random stream before each image in `_iter_images`.
        self.noise_vars = [
            var
            for name, var in self.Gs.components.synthesis.vars.items()
            if name.startswith('noise')
        ]

    @staticmethod
    def _synthesis_kwargs():
        """Return the ``run()`` options shared by all image-generating methods."""
        kwargs = dnnlib.EasyDict()
        kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
        kwargs.randomize_noise = False
        return kwargs

    @staticmethod
    def _per_image_psi(truncation_psi, count):
        """Return an indexable of truncation values, one per image.

        A scalar (including ``None``) is repeated *count* times; a list,
        tuple or array is returned unchanged and indexed per image.
        """
        if np.ndim(truncation_psi) == 0:
            return [truncation_psi] * count
        return truncation_psi

    def _iter_images(self, zs, truncation_psi, label):
        """Lazily yield one RGB ``PIL.Image`` for every latent in *zs*."""
        Gs_kwargs = self._synthesis_kwargs()
        psis = self._per_image_psi(truncation_psi, len(zs))
        for z_idx, z in log_progress(enumerate(zs), size=len(zs), name="Generating images"):
            Gs_kwargs.truncation_psi = psis[z_idx]
            # Refill the noise variables from the same fixed seed for every
            # image, so the noise is identical from one image to the next.
            noise_rnd = np.random.RandomState(1)
            tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in self.noise_vars})
            images = self.Gs.run(z, label, **Gs_kwargs)
            yield PIL.Image.fromarray(images[0], 'RGB')

    def generate_images_in_w_space(self, dlatents, truncation_psi):
        """Run the synthesis network on each entry of *dlatents*.

        Args:
            dlatents: Iterable of inputs, each passed as-is to
                ``Gs.components.synthesis.run``.
            truncation_psi: Forwarded to the synthesis run as a keyword.

        Returns:
            A list with one RGB ``PIL.Image`` per entry of *dlatents*.
        """
        Gs_kwargs = self._synthesis_kwargs()
        Gs_kwargs.truncation_psi = truncation_psi

        imgs = []
        for dlatent in log_progress(dlatents, name="Generating images"):
            row_images = self.Gs.components.synthesis.run(dlatent, **Gs_kwargs)
            imgs.append(PIL.Image.fromarray(row_images[0], 'RGB'))
        return imgs

    def generate_images(self, zs, truncation_psi, class_idx=None):
        """Generate one image per latent vector in *zs*.

        Args:
            zs: Sized sequence of latent inputs for ``Gs.run``.
            truncation_psi: A single value applied to every image, or a
                sequence holding one value per latent.
            class_idx: Optional index set to 1 in a one-hot label; when
                ``None`` the network is run with ``label=None``.

        Returns:
            A list of RGB ``PIL.Image`` objects, in the order of *zs*.
        """
        label = None
        if class_idx is not None:
            label = np.zeros([1] + self.Gs.input_shapes[1][1:])
            label[:, class_idx] = 1
        return list(self._iter_images(zs, truncation_psi, label))

    def generate_from_zs(self, zs, truncation_psi=0.5):
        """Generate an image for every latent in *zs* and display each one.

        Each image is passed to ``imshow`` as soon as it is produced;
        nothing is returned.
        """
        for img in self._iter_images(zs, truncation_psi, None):
            imshow(img)

    def generate_random_zs(self, size):
        """Return *size* latent vectors built from freshly drawn random seeds."""
        # dtype is explicit because 2**32 does not fit the 32-bit default
        # integer that some NumPy builds use for randint.
        seeds = np.random.randint(2**32, size=size, dtype=np.int64)
        return self.generate_zs_from_seeds(seeds)

    def generate_zs_from_seeds(self, seeds):
        """Return one latent of shape ``(1, *Gs.input_shape[1:])`` per seed.

        Every latent is drawn from its own ``np.random.RandomState(seed)``,
        so the same seed always yields the same vector.
        """
        latent_dims = self.Gs.input_shape[1:]
        return [np.random.RandomState(seed).randn(1, *latent_dims) for seed in seeds]

    def generate_images_from_seeds(self, seeds, truncation_psi):
        """Generate and display the image for the first seed in *seeds*.

        Only the first seed's image is returned, so only that latent is
        synthesized.

        Returns:
            A tuple ``(image, imshow(image))``.
        """
        first_z = self.generate_zs_from_seeds(seeds)[:1]
        ima = self.generate_images(first_z, truncation_psi)[0]
        return ima, imshow(ima)

    def generate_randomly(self, truncation_psi=0.5):
        """Generate and display one image from a random seed.

        Returns:
            The ``(image, display_result)`` tuple produced by
            ``generate_images_from_seeds``.
        """
        seed = np.random.randint(4294967295, size=1, dtype=np.int64)
        return self.generate_images_from_seeds(seed, truncation_psi=truncation_psi)

    def generate_grid(self, truncation_psi=0.7):
        """Generate nine random images and combine them with ``create_image_grid``."""
        seeds = np.random.randint((2**32 - 1), size=9, dtype=np.int64)
        images = self.generate_images(self.generate_zs_from_seeds(seeds), truncation_psi)
        # 0.7 and 3 are forwarded positionally to create_image_grid
        # (presumably scale and row count -- confirm against utils).
        return create_image_grid(images, 0.7, 3)

    def generate_animation(self, size=9, steps=10, trunc_psi=0.5):
        """Render a looping interpolation between random latents to a video.

        Args:
            size: Number of random key latents.
            steps: Number of interpolated frames between consecutive keys.
            trunc_psi: Truncation value used for every frame.

        Returns:
            The result of ``show_animation('animation.mp4')``.
        """
        seeds = list(np.random.randint((2**32) - 1, size=size, dtype=np.int64))
        # Append the first seed again so the sequence ends where it started.
        seeds = seeds + [seeds[0]]
        zs = self.generate_zs_from_seeds(seeds)

        imgs = self.generate_images(self.interpolate(zs, steps=steps), trunc_psi)
        movie_name = 'animation.mp4'
        with imageio.get_writer(movie_name, mode='I') as writer:
            for image in log_progress(imgs, name="Creating animation"):
                writer.append_data(np.array(image))
        return show_animation(movie_name)

    def convertZtoW(self, latent, truncation_psi=0.7, truncation_cutoff=9):
        """Map *latent* through the mapping network and apply truncation.

        The first *truncation_cutoff* layers of every batch entry are pulled
        towards ``dlatent_avg``:
        ``avg + (w - avg) * truncation_psi``.

        Args:
            latent: Input passed to ``Gs.components.mapping.run``.
            truncation_psi: Interpolation factor towards the average.
            truncation_cutoff: Number of leading layers to truncate. Values
                larger than the layer count are clamped, values <= 0 leave
                the result untouched, and ``None`` truncates every layer.

        Returns:
            The array returned by the mapping network, modified in place.
        """
        dlatent = self.Gs.components.mapping.run(latent, None)
        dlatent_avg = self.Gs.get_var('dlatent_avg')
        if truncation_cutoff is None:
            layers = slice(None)
        else:
            layers = slice(0, max(0, truncation_cutoff))
        dlatent[:, layers] = (dlatent[:, layers] - dlatent_avg) * truncation_psi + dlatent_avg
        return dlatent

    def interpolate(self, zs, steps=10):
        """Linearly interpolate between consecutive entries of *zs*.

        For each neighbouring pair, *steps* points are emitted at fractions
        ``0, 1/steps, ..., (steps - 1)/steps``; the end point of a pair is
        therefore produced only as the start of the next pair, and the very
        last entry of *zs* is never emitted.
        """
        out = []
        for start, end in zip(zs, zs[1:]):
            for index in range(steps):
                fraction = index / float(steps)
                out.append(end * fraction + start * (1 - fraction))
        return out
|
|