| import multiprocessing |
| import time |
| import traceback |
|
|
| import cv2 |
| import numpy as np |
| import numpy.linalg as npla |
|
|
| from core import mplib |
| from core import imagelib |
| from core.interact import interact as io |
| from core.joblib import SubprocessGenerator, ThisThreadGenerator |
| from core import mathlib |
| from facelib import LandmarksProcessor, FaceType |
| from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, |
| SampleType) |
|
|
class SampleGeneratorSAE(SampleGeneratorBase):
    """Yield batches of augmented source/destination face samples.

    Each batch is a list of eight arrays, in this order:

        0. src warped face        4. dst warped face
        1. src target face        5. dst target face
        2. src target face mask   6. dst target face mask
        3. src eyes/mouth mask    7. dst eyes/mouth mask

    Batches are produced by ``batch_func`` running either in the calling
    thread (debug mode, one generator) or in eight ``SubprocessGenerator``
    workers that ``__next__`` polls in round-robin order.
    """

    def __init__(self, src_samples_path, dst_samples_path,
                 resolution,
                 face_type,
                 random_src_flip=False,
                 random_dst_flip=False,
                 ct_mode=None,
                 uniform_yaw_distribution=False,
                 data_format='NHWC',
                 debug=False, batch_size=1,
                 raise_on_no_data=True,
                 **kwargs):
        """Load both sample sets and prepare (but do not start) the generators.

        Args:
            src_samples_path: location handed to ``SampleLoader.load`` for the
                source faces.
            dst_samples_path: location handed to ``SampleLoader.load`` for the
                destination faces.
            resolution: side length, in pixels, of every square output image.
            face_type: target face type passed to
                ``LandmarksProcessor.get_transform_mat``.
            random_src_flip: allow random horizontal mirroring of source
                samples (applied with probability 0.4 when allowed).
            random_dst_flip: same as ``random_src_flip`` for destination
                samples.
            ct_mode: color transfer mode forwarded to
                ``imagelib.color_transfer``; ``None`` disables color transfer.
                When set, each source sample is color-transferred towards the
                destination sample it is paired with in the batch.
            uniform_yaw_distribution: when true, sample indexes are grouped
                into yaw bins (see ``_filter_uniform_yaw``) and served through
                ``mplib.Index2DHost`` instead of a flat ``mplib.IndexHost``.
            data_format: ``"NCHW"`` makes every output channel-first; any
                other value leaves outputs as height/width/channel.
            debug: forwarded to ``SampleGeneratorBase.__init__``; in debug
                mode a single in-thread generator is used.
            batch_size: forwarded to ``SampleGeneratorBase.__init__``.
            raise_on_no_data: when true (default) an empty sample set raises
                ``ValueError``; when false the constructor returns early and
                the generator stays uninitialized (``is_initialized()`` is
                false and ``__next__`` returns ``[]``).
            **kwargs: accepted for call-site compatibility and ignored.

        Raises:
            ValueError: a sample set is empty and ``raise_on_no_data`` is true.
        """
        # Presumably stores self.debug / self.batch_size, both of which are
        # read below and in batch_func -- verify against SampleGeneratorBase.
        super().__init__(debug, batch_size)
        self.initialized = False
        self.resolution = resolution
        self.face_type = face_type
        self.random_src_flip = random_src_flip
        self.random_dst_flip = random_dst_flip
        self.ct_mode = ct_mode
        self.data_format = data_format

        # Defined before any early return so that start(), set_face_scale()
        # and __next__() remain safe on an uninitialized instance.
        self.comm_qs = []
        self.generators = []
        self.generator_counter = -1

        self.generators_count = 1 if self.debug else 8

        src_samples = SampleLoader.load(SampleType.FACE, src_samples_path)
        src_samples_len = len(src_samples)
        if src_samples_len == 0:
            if raise_on_no_data:
                raise ValueError(f'No samples in {src_samples_path}')
            return

        dst_samples = SampleLoader.load(SampleType.FACE, dst_samples_path)
        dst_samples_len = len(dst_samples)
        if dst_samples_len == 0:
            if raise_on_no_data:
                raise ValueError(f'No samples in {dst_samples_path}')
            return

        if uniform_yaw_distribution:
            src_index_host = self._filter_uniform_yaw(src_samples)
            dst_index_host = self._filter_uniform_yaw(dst_samples)
        else:
            src_index_host = mplib.IndexHost(src_samples_len)
            dst_index_host = mplib.IndexHost(dst_samples_len)

        # NOTE(review): this host is created and handed to every worker, but
        # batch_func never reads from it (color transfer uses the paired dst
        # sample instead). Kept to preserve the worker argument layout.
        ct_index_host = mplib.IndexHost(dst_samples_len) if ct_mode is not None else None

        # One command queue per generator; see set_face_scale().
        self.comm_qs = [multiprocessing.Queue() for _ in range(self.generators_count)]

        def worker_args(comm_q):
            # Every worker gets its own index-host clients.
            return (comm_q, src_samples, dst_samples,
                    src_index_host.create_cli(),
                    dst_index_host.create_cli(),
                    ct_index_host.create_cli() if ct_index_host is not None else None)

        if self.debug:
            self.generators = [ThisThreadGenerator(self.batch_func, worker_args(self.comm_qs[0]))]
        else:
            self.generators = [SubprocessGenerator(self.batch_func, worker_args(comm_q), start_now=False)
                               for comm_q in self.comm_qs]

        self.initialized = True

    def start(self):
        """Start all subprocess generators; a no-op in debug mode or when
        construction ended early without creating any generator."""
        if not self.debug and self.generators:
            SubprocessGenerator.start_in_parallel(self.generators)

    def _filter_uniform_yaw(self, samples):
        """Group sample indexes into yaw bins and wrap them in an Index2DHost.

        The negated yaw of every sample (second item of
        ``sample.get_pitch_yaw_roll()``) is assigned to one of 128 bins whose
        lower edges are evenly spaced over [-1.2, 1.2]. The first bin is open
        below and the last bin is open above, so out-of-range yaws land in the
        edge bins. Empty bins are dropped.

        Args:
            samples: sequence of objects exposing ``get_pitch_yaw_roll()``.

        Returns:
            ``mplib.Index2DHost`` built from the list of non-empty bins, each
            bin being a list of indexes into ``samples`` in ascending order.
        """
        grads = 128
        grads_space = np.linspace(-1.2, 1.2, grads)

        # Negated yaw per sample, computed once instead of once per bin.
        yaws = np.asarray([-sample.get_pitch_yaw_roll()[1] for sample in samples],
                          dtype=np.float64)

        yaws_sample_list = []
        for g in io.progress_bar_generator(range(grads), "Sort by yaw"):
            in_bin = np.ones(len(yaws), dtype=bool)
            if g > 0:
                # Every bin but the first has a closed lower edge.
                in_bin &= yaws >= grads_space[g]
            if g < grads - 1:
                # Every bin but the last has an open upper edge.
                in_bin &= yaws < grads_space[g + 1]

            yaw_samples = np.nonzero(in_bin)[0].tolist()
            if yaw_samples:
                yaws_sample_list.append(yaw_samples)

        return mplib.Index2DHost(yaws_sample_list)

    def set_face_scale(self, scale):
        """Send a new base face scale to every generator.

        The ``('face_scale', scale)`` command is queued on each generator's
        command queue and picked up by ``batch_func`` before it builds its
        next batch.
        """
        for comm_q in self.comm_qs:
            comm_q.put(('face_scale', scale))

    def is_initialized(self):
        """Return True once the constructor has finished setting up generators."""
        return self.initialized

    def __iter__(self):
        """The generator is its own iterator."""
        return self

    def __next__(self):
        """Return the next batch, cycling through the generators in order.

        Returns an empty list when the generator was never initialized.
        """
        if not self.initialized:
            return []

        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    def batch_func(self, param):
        """Endless generator body executed by each worker.

        Args:
            param: tuple ``(comm_q, src_samples, dst_samples, src_index_host,
                dst_index_host, ct_index_host)`` as assembled in ``__init__``.
                The three index-host entries are the client objects returned
                by ``create_cli()``; ``ct_index_host`` is currently unused.

        Yields:
            A list of eight ``np.ndarray`` batches (see the class docstring
            for their order), each stacking ``batch_size`` samples.
        """
        comm_q, src_samples, dst_samples, src_index_host, dst_index_host, ct_index_host = param

        batch_size = self.batch_size
        resolution = self.resolution
        face_type = self.face_type
        data_format = self.data_format
        random_src_flip = self.random_src_flip
        random_dst_flip = self.random_dst_flip
        ct_mode = self.ct_mode

        # Ranges of the random jitter applied to the target alignment.
        rotation_range = [-10, 10]
        scale_range = [-0.05, 0.05]
        tx_range = [-0.05, 0.05]
        ty_range = [-0.05, 0.05]
        rnd_state = np.random

        # Base scale; updated at runtime through 'face_scale' commands.
        face_scale = 1.0

        # Side length of the intermediate image on which the random warp is applied.
        hi_res = 1024

        def gen_sample(sample, target_face_type, resolution, allow_flip=False, scale=1.0, ct_mode=None, ct_sample=None):
            """Build one (warped, target, mask, eyes/mouth mask) tuple.

            ``scale`` is the centre of the random scale jitter. Color transfer
            is applied only when both ``ct_mode`` and ``ct_sample`` are given.
            """
            tx = rnd_state.uniform(tx_range[0], tx_range[1])
            ty = rnd_state.uniform(ty_range[0], ty_range[1])
            rotation = rnd_state.uniform(rotation_range[0], rotation_range[1])
            scale = rnd_state.uniform(scale + scale_range[0], scale + scale_range[1])

            # When allowed, mirror with probability 4/10.
            flip = allow_flip and rnd_state.randint(10) < 4

            face_lmrks = sample.landmarks
            face = sample.load_bgr()
            h, w, _ = face.shape

            # HEAD samples are aligned as HEAD in the intermediate image,
            # everything else as HEAD_FACE.
            hi_face_type = FaceType.HEAD if sample.face_type == FaceType.HEAD else FaceType.HEAD_FACE
            hi_mat = LandmarksProcessor.get_transform_mat(face_lmrks, hi_res, hi_face_type)

            hi_lmrks = LandmarksProcessor.transform_points(face_lmrks, hi_mat)
            hi_warp_params = imagelib.gen_warp_params(hi_res)

            # Both paths (via the intermediate image, and directly from the
            # source image) share the same random tx/ty/rotation/scale.
            hi_to_target_mat = LandmarksProcessor.get_transform_mat(hi_lmrks, resolution, target_face_type)
            hi_to_target_mat = mathlib.transform_mat(hi_to_target_mat, resolution, tx, ty, rotation, scale)

            face_to_target_mat = LandmarksProcessor.get_transform_mat(face_lmrks, resolution, target_face_type)
            face_to_target_mat = mathlib.transform_mat(face_to_target_mat, resolution, tx, ty, rotation, scale)

            use_ct = ct_mode is not None and ct_sample is not None
            if use_ct:
                ct_bgr = cv2.resize(ct_sample.load_bgr(), (w, h), interpolation=cv2.INTER_LINEAR)

            # Warped input: source -> intermediate image -> random warp -> target alignment.
            warped_face = imagelib.color_transfer(ct_mode, face, ct_bgr) if use_ct else face
            warped_face = cv2.warpAffine(warped_face, hi_mat, (hi_res, hi_res), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC)
            warped_face = np.clip(imagelib.warp_by_params(hi_warp_params, warped_face, can_warp=True, can_transform=False, can_flip=False, border_replicate=cv2.BORDER_REPLICATE), 0, 1)
            warped_face = cv2.warpAffine(warped_face, hi_to_target_mat, (resolution, resolution), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC)

            # Target: source -> target alignment, without the random warp.
            target_face = imagelib.color_transfer(ct_mode, face, ct_bgr) if use_ct else face
            target_face = cv2.warpAffine(target_face, face_to_target_mat, (resolution, resolution), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC)

            # Prefer the sample's XSeg mask; fall back to the landmark hull.
            face_mask = sample.get_xseg_mask()
            if face_mask is not None:
                if face_mask.shape[0] != h or face_mask.shape[1] != w:
                    face_mask = cv2.resize(face_mask, (w, h), interpolation=cv2.INTER_CUBIC)
                    face_mask = imagelib.normalize_channels(face_mask, 1)
            else:
                face_mask = LandmarksProcessor.get_image_hull_mask(face.shape, face_lmrks, eyebrows_expand_mod=sample.eyebrows_expand_mod)
            face_mask = np.clip(face_mask, 0, 1)

            target_face_mask = cv2.warpAffine(face_mask, face_to_target_mat, (resolution, resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LINEAR)
            target_face_mask = imagelib.normalize_channels(target_face_mask, 1)
            target_face_mask = np.clip(target_face_mask, 0, 1)

            # Combined eyes + mouth mask.
            em_mask = np.clip(LandmarksProcessor.get_image_eye_mask(face.shape, face_lmrks) +
                              LandmarksProcessor.get_image_mouth_mask(face.shape, face_lmrks), 0, 1)

            target_face_em = cv2.warpAffine(em_mask, face_to_target_mat, (resolution, resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LINEAR)
            target_face_em = imagelib.normalize_channels(target_face_em, 1)

            # Rescale so the peak is 1 (skipped for an all-zero mask), then
            # restrict the result to the face mask.
            div = target_face_em.max()
            if div != 0.0:
                target_face_em = target_face_em / div
            target_face_em = target_face_em * target_face_mask

            if flip:
                # Mirror all four outputs along the width axis.
                warped_face = warped_face[:, ::-1, ...]
                target_face = target_face[:, ::-1, ...]
                target_face_mask = target_face_mask[:, ::-1, ...]
                target_face_em = target_face_em[:, ::-1, ...]

            return warped_face, target_face, target_face_mask, target_face_em

        while True:
            # Apply any commands queued since the previous batch.
            while not comm_q.empty():
                cmd, cmd_arg = comm_q.get()
                if cmd == 'face_scale':
                    face_scale = cmd_arg

            batches = [[] for _ in range(8)]

            src_indexes = src_index_host.multi_get(batch_size)
            dst_indexes = dst_index_host.multi_get(batch_size)

            for n_batch in range(batch_size):
                src_sample = src_samples[src_indexes[n_batch]]
                dst_sample = dst_samples[dst_indexes[n_batch]]

                # Source first, then destination; only the source is
                # color-transferred (towards its paired destination sample).
                src_outputs = gen_sample(src_sample, face_type, resolution, allow_flip=random_src_flip, scale=face_scale, ct_mode=ct_mode, ct_sample=dst_sample)
                dst_outputs = gen_sample(dst_sample, face_type, resolution, allow_flip=random_dst_flip, scale=face_scale)

                outputs = src_outputs + dst_outputs
                if data_format == "NCHW":
                    outputs = [np.transpose(output, (2, 0, 1)) for output in outputs]

                for batch, output in zip(batches, outputs):
                    batch.append(output)

            yield [np.array(batch) for batch in batches]
|
|