| from ugatit.ops import * |
| from ugatit.utils import * |
| from glob import glob |
| import time |
| from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch |
| import numpy as np |
| from ugatit.utils import * |
|
|
class UgatitTest:
    """Test-phase UGATIT model: builds the generator graph, restores a
    checkpoint and translates single images from domain A to domain B.

    All hyper-parameters are fixed in ``__init__``. The checkpoint is looked
    up under ``<checkpoint_dir>/<model_dir>``, and ``model_dir`` is derived
    from those hyper-parameters, so changing them changes which checkpoint
    directory is read.

    Typical use: ``build_model()`` -> ``loadModel()`` -> ``test(path)``.
    """

    def __init__(self, sess, checkpoint_dir):
        """Store the session and checkpoint location and set fixed settings.

        Args:
            sess: TensorFlow session used for variable initialisation,
                checkpoint restore and inference.
            checkpoint_dir: Root directory under which the model-specific
                checkpoint directory (``model_dir``) is expected.
        """
        self.light = False
        self.model_name = 'UGATIT_light' if self.light else 'UGATIT'

        self.sess = sess
        self.phase = 'test'
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = 'results'
        self.log_dir = 'logs'
        self.dataset_name = 'selfie2anime'
        self.augment_flag = True

        self.epoch = 100
        self.iteration = 10000
        self.decay_flag = True
        self.decay_epoch = 50

        self.gan_type = 'lsgan'

        # build_model() hard-codes a batch dimension of 1 in its placeholders
        # and MLP() reshapes with this value, so the two must stay in sync.
        self.batch_size = 1
        self.print_freq = 1000
        self.save_freq = 1000

        self.init_lr = 0.0001
        self.ch = 64

        # Weight
        self.adv_weight = 1
        self.cycle_weight = 10
        self.identity_weight = 10
        self.cam_weight = 1000
        self.ld = 10
        self.smoothing = True

        # Generator
        self.n_res = 4

        # Discriminator
        self.n_dis = 6
        self.n_critic = 1
        self.sn = True

        self.img_size = 256
        self.img_ch = 3

        # Only used for the "max dataset number" line printed below.
        self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
        self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB'))
        self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))

        print()

        print("##### Information #####")
        print("# light : ", self.light)
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# max dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# iteration per epoch : ", self.iteration)
        print("# smoothing : ", self.smoothing)

        print()

        print("##### Generator #####")
        print("# residual blocks : ", self.n_res)

        print()

        print("##### Discriminator #####")
        print("# discriminator layer : ", self.n_dis)
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)

        print()

        print("##### Weight #####")
        print("# adv_weight : ", self.adv_weight)
        print("# cycle_weight : ", self.cycle_weight)
        print("# identity_weight : ", self.identity_weight)
        print("# cam_weight : ", self.cam_weight)

    def generator(self, x_init, reuse=False, scope="generator"):
        """Build the generator graph for ``x_init`` inside ``scope``.

        Args:
            x_init: Input image tensor.
            reuse: Passed to ``tf.variable_scope`` and on to ``MLP``.
            scope: Name of the enclosing variable scope.

        Returns:
            ``(x, cam_logit, heatmap)``: the tanh-activated 3-channel output,
            the concatenated average-pool / max-pool CAM logits, and the
            squeezed channel-sum of the features after the 1x1 convolution.
        """
        channel = self.ch
        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv')
            x = instance_norm(x, scope='ins_norm')
            x = relu(x)

            # Down-sampling: two stride-2 convolutions.
            # NOTE(review): the source's indentation was lost; `channel` is
            # taken to double inside the loop (ch -> 2ch -> 4ch), which equals
            # the `self.ch * self.n_res` width produced by MLP() for the
            # default n_res=4 -- confirm against the training code.
            for i in range(2):
                x = conv(x, channel*2, kernel=3, stride=2, pad=1, pad_type='reflect', scope='conv_'+str(i))
                x = instance_norm(x, scope='ins_norm_'+str(i))
                x = relu(x)

                channel = channel * 2

            # Residual bottleneck.
            for i in range(self.n_res):
                x = resblock(x, channel, scope='resblock_' + str(i))

            # Class activation map: the same 'CAM_logit' weights are applied
            # to the average-pooled and (with reuse=True) max-pooled features.
            cam_x = global_avg_pooling(x)
            cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, scope='CAM_logit')
            x_gap = tf.multiply(x, cam_x_weight)

            cam_x = global_max_pooling(x)
            cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, reuse=True, scope='CAM_logit')
            x_gmp = tf.multiply(x, cam_x_weight)

            cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
            x = tf.concat([x_gap, x_gmp], axis=-1)

            # 1x1 convolution brings the concatenated features back to `channel`.
            x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
            x = relu(x)

            heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))

            # Gamma / beta for the adaptive residual blocks.
            gamma, beta = self.MLP(x, reuse=reuse)

            for i in range(self.n_res):
                x = adaptive_ins_layer_resblock(x, channel, gamma, beta, smoothing=self.smoothing, scope='adaptive_resblock' + str(i))

            # Up-sampling.
            # NOTE(review): `channel` is assumed to halve inside the loop,
            # mirroring the down-sampling path -- confirm.
            for i in range(2):
                x = up_sample(x, scale_factor=2)
                x = conv(x, channel//2, kernel=3, stride=1, pad=1, pad_type='reflect', scope='up_conv_'+str(i))
                x = layer_instance_norm(x, scope='layer_ins_norm_'+str(i))
                x = relu(x)

                channel = channel // 2

            x = conv(x, channels=3, kernel=7, stride=1, pad=3, pad_type='reflect', scope='G_logit')
            x = tanh(x)

            return x, cam_logit, heatmap

    def MLP(self, x, use_bias=True, reuse=False, scope='MLP'):
        """Compute ``gamma`` and ``beta`` from features ``x``.

        Args:
            x: Feature tensor; globally average-pooled first when
                ``self.light`` is set.
            use_bias: Forwarded to every ``fully_connected`` call.
            reuse: Passed to ``tf.variable_scope``.
            scope: Name of the enclosing variable scope.

        Returns:
            ``(gamma, beta)``, each reshaped to
            ``[self.batch_size, 1, 1, self.ch * self.n_res]``.
        """
        channel = self.ch * self.n_res

        if self.light:
            x = global_avg_pooling(x)

        with tf.variable_scope(scope, reuse=reuse):
            for i in range(2):
                x = fully_connected(x, channel, use_bias, scope='linear_' + str(i))
                x = relu(x)

            gamma = fully_connected(x, channel, use_bias, scope='gamma')
            beta = fully_connected(x, channel, use_bias, scope='beta')

            gamma = tf.reshape(gamma, shape=[self.batch_size, 1, 1, channel])
            beta = tf.reshape(beta, shape=[self.batch_size, 1, 1, channel])

            return gamma, beta

    def discriminator(self, x_init, reuse=False, scope="discriminator"):
        """Build the local and global discriminators on ``x_init``.

        Returns:
            ``(D_logit, D_CAM_logit, local_heatmap, global_heatmap)`` where
            the first two are ``[local, global]`` lists of logits.
        """
        D_logit = []
        D_CAM_logit = []
        with tf.variable_scope(scope, reuse=reuse):
            local_x, local_cam, local_heatmap = self.discriminator_local(x_init, reuse=reuse, scope='local')
            global_x, global_cam, global_heatmap = self.discriminator_global(x_init, reuse=reuse, scope='global')

            D_logit.extend([local_x, global_x])
            D_CAM_logit.extend([local_cam, global_cam])

            return D_logit, D_CAM_logit, local_heatmap, global_heatmap

    def discriminator_global(self, x_init, reuse=False, scope='discriminator_global'):
        """Build the deeper discriminator (``n_dis - 2`` strided convs after ``conv_0``).

        Returns:
            ``(x, cam_logit, heatmap)``: the 1-channel 'D_logit' output, the
            concatenated CAM logits and the squeezed channel-sum heat map.
        """
        with tf.variable_scope(scope, reuse=reuse):
            channel = self.ch
            x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
            x = lrelu(x, 0.2)

            # NOTE(review): indentation was lost in the source; `channel` is
            # assumed to double on every iteration -- confirm.
            for i in range(1, self.n_dis - 1):
                x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
                x = lrelu(x, 0.2)

                channel = channel * 2

            x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
            x = lrelu(x, 0.2)

            channel = channel * 2

            # Shared 'CAM_logit' weights over average- and max-pooled features.
            cam_x = global_avg_pooling(x)
            cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
            x_gap = tf.multiply(x, cam_x_weight)

            cam_x = global_max_pooling(x)
            cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
            x_gmp = tf.multiply(x, cam_x_weight)

            cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
            x = tf.concat([x_gap, x_gmp], axis=-1)

            x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
            x = lrelu(x, 0.2)

            heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))

            x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')

            return x, cam_logit, heatmap

    def discriminator_local(self, x_init, reuse=False, scope='discriminator_local'):
        """Build the shallower discriminator (``n_dis - 4`` strided convs after ``conv_0``).

        Same structure as ``discriminator_global`` with two fewer strided
        convolutions.

        Returns:
            ``(x, cam_logit, heatmap)``: the 1-channel 'D_logit' output, the
            concatenated CAM logits and the squeezed channel-sum heat map.
        """
        with tf.variable_scope(scope, reuse=reuse):
            channel = self.ch
            x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
            x = lrelu(x, 0.2)

            # NOTE(review): indentation was lost in the source; `channel` is
            # assumed to double on every iteration -- confirm.
            for i in range(1, self.n_dis - 2 - 1):
                x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
                x = lrelu(x, 0.2)

                channel = channel * 2

            x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
            x = lrelu(x, 0.2)

            channel = channel * 2

            # Shared 'CAM_logit' weights over average- and max-pooled features.
            cam_x = global_avg_pooling(x)
            cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
            x_gap = tf.multiply(x, cam_x_weight)

            cam_x = global_max_pooling(x)
            cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
            x_gmp = tf.multiply(x, cam_x_weight)

            cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
            x = tf.concat([x_gap, x_gmp], axis=-1)

            x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
            x = lrelu(x, 0.2)

            heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))

            x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')

            return x, cam_logit, heatmap

    def generate_a2b(self, x_A, reuse=False):
        """Run the generator under scope ``generator_B``; return ``(output, cam_logit)``."""
        out, cam, _ = self.generator(x_A, reuse=reuse, scope="generator_B")

        return out, cam

    def generate_b2a(self, x_B, reuse=False):
        """Run the generator under scope ``generator_A``; return ``(output, cam_logit)``."""
        out, cam, _ = self.generator(x_B, reuse=reuse, scope="generator_A")

        return out, cam

    def build_model(self):
        """Create the test placeholders and both generator outputs.

        Sets ``test_domain_A`` / ``test_domain_B`` (float32 placeholders of
        shape ``[1, img_size, img_size, img_ch]``) and the corresponding
        ``test_fake_B`` / ``test_fake_A`` tensors. Variables are created with
        ``reuse=False``, so this is meant to be called once per graph.
        """
        self.test_domain_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_A')
        self.test_domain_B = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_B')

        self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
        self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)

    @property
    def model_dir(self):
        """Directory name that encodes the model's hyper-parameters.

        Used as the sub-directory for both checkpoints and results.
        """
        n_res = str(self.n_res) + 'resblock'
        n_dis = str(self.n_dis) + 'dis'

        smoothing = '_smoothing' if self.smoothing else ''
        sn = '_sn' if self.sn else ''

        return "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}{}{}".format(self.model_name, self.dataset_name,
                                                          self.gan_type, n_res, n_dis,
                                                          self.n_critic,
                                                          self.adv_weight, self.cycle_weight, self.identity_weight,
                                                          self.cam_weight, sn, smoothing)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint found in ``checkpoint_dir/model_dir``.

        Requires ``self.saver`` to have been created (``loadModel`` does so).

        Returns:
            ``(True, counter)`` on success, where ``counter`` is the integer
            after the last ``-`` in the checkpoint file name, otherwise
            ``(False, 0)``.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter

        print(" [*] Failed to find a checkpoint")
        return False, 0

    def _resolve_result_dir(self):
        """Return ``result_dir`` with ``model_dir`` appended exactly once."""
        model_dir = self.model_dir
        if os.path.basename(os.path.normpath(self.result_dir)) == model_dir:
            # Already resolved by an earlier loadModel() call.
            return self.result_dir
        return os.path.join(self.result_dir, model_dir)

    def loadModel(self):
        """Initialise variables, restore the checkpoint and prepare ``result_dir``.

        Returns:
            ``True`` if a checkpoint was restored, ``False`` otherwise (in
            which case the variables keep their freshly initialised values).
        """
        tf.global_variables_initializer().run(session=self.sess)

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)

        # Idempotent: calling loadModel() again must not nest
        # results/<model_dir>/<model_dir>/...
        self.result_dir = self._resolve_result_dir()
        check_folder(self.result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        return could_load

    def test(self, sample_file):
        """Translate one image through ``test_fake_B`` and save the result.

        Args:
            sample_file: Path of the input image; it is read with
                ``load_test_data(sample_file, size=self.img_size)``.

        Returns:
            Path of the saved image: ``result_dir`` joined with the input
            file's base name.
        """
        print('Processing A image: ' + sample_file)
        sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
        image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file)))

        fake_img = self.sess.run(self.test_fake_B, feed_dict={self.test_domain_A: sample_image})
        save_images(fake_img, [1, 1], image_path)

        return image_path
|
|
|
|
# Process-wide cache of the built and restored model so that repeated calls
# to main_test() reuse one graph and one session.
gan = None


def main_test(img_path, checkpoint_dir):
    """Translate ``img_path`` with the cached UGATIT model.

    On the first call a session is opened, the model graph is built and the
    checkpoint is restored; later calls reuse that model. Building the graph
    again on the cached instance would fail because its variables already
    exist, and opening a fresh session per call would leak sessions that the
    cached model never uses.

    Args:
        img_path: Path of the image to translate.
        checkpoint_dir: Root checkpoint directory. Only consulted on the
            first call, when the model is created.

    Returns:
        Absolute path of the saved result image.
    """
    global gan
    if gan is None:
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        model = UgatitTest(sess, checkpoint_dir)

        model.build_model()

        show_all_variables()

        model.loadModel()

        # Publish only a fully set-up model so a failure above never leaves
        # a half-initialised instance in the cache.
        gan = model

    result = gan.test(img_path)
    print(" [*] Test finished!")
    print(result)
    return os.path.abspath(result)
|
|
if __name__ == '__main__':
    # main_test() requires both an image path and a checkpoint directory;
    # calling it with the image path alone raises TypeError.
    import argparse

    parser = argparse.ArgumentParser(description='Translate one image with a trained UGATIT checkpoint.')
    parser.add_argument('img_path', nargs='?',
                        default='/home/hylee/cartoon/myp2c/imgs/src/im4.jpg',
                        help='image to translate')
    # NOTE(review): 'checkpoint' is an assumed default location -- adjust to
    # wherever the trained model actually lives.
    parser.add_argument('--checkpoint_dir', default='checkpoint',
                        help='root directory containing the model checkpoint folder')
    args = parser.parse_args()

    main_test(args.img_path, args.checkpoint_dir)