| import glob |
| import os |
| from configs import global_config, paths_config, hyperparameters |
| from scripts.latent_creators.sg2_plus_latent_creator import SG2PlusLatentCreator |
| from scripts.latent_creators.e4e_latent_creator import E4ELatentCreator |
| from scripts.run_pti import run_PTI |
| import pickle |
| import torch |
| from utils.models_utils import toogle_grad, load_old_G |
|
|
|
|
class ExperimentRunner:
    """Drive a PTI run and, optionally, the creation of comparison latents.

    On construction the input data directory is globbed and a generator is
    obtained from ``load_old_G``; that generator is then passed to
    ``toogle_grad`` with ``False`` (presumably disabling its gradients --
    confirm against ``utils.models_utils``).
    """

    def __init__(self, run_id=''):
        """Collect input paths and load the reference generator.

        Args:
            run_id: Identifier of the run. It is replaced by the value
                ``run_PTI`` returns when ``run_experiment`` is called with
                a truthy ``run_pt``.
        """
        input_glob = f'{paths_config.input_data_path}/*'
        # Both attributes come from the same pattern; each is globbed on its
        # own so the two attributes hold independent list objects.
        self.images_paths = glob.glob(input_glob)
        self.target_paths = glob.glob(input_glob)
        self.run_id = run_id
        self.sampled_ws = None

        generator = load_old_G()
        toogle_grad(generator, False)
        self.old_G = generator

    def run_experiment(self, run_pt, create_other_latents, use_multi_id_training, use_wandb=False):
        """Run the requested experiment stages and return the run id.

        Args:
            run_pt: When truthy, call ``run_PTI`` and store the run id it
                returns on ``self.run_id``.
            create_other_latents: When truthy, create latents with
                ``SG2PlusLatentCreator`` and then with ``E4ELatentCreator``.
            use_multi_id_training: Forwarded to ``run_PTI``.
            use_wandb: Forwarded to ``run_PTI`` and to both latent creators.

        Returns:
            ``self.run_id`` after any update made by ``run_PTI``.
        """
        if run_pt:
            self.run_id = run_PTI(
                self.run_id,
                use_wandb=use_wandb,
                use_multi_id_training=use_multi_id_training,
            )

        if create_other_latents:
            # Each creator is built and run in turn: SG2+ first, then e4e.
            for creator_cls in (SG2PlusLatentCreator, E4ELatentCreator):
                creator_cls(use_wandb=use_wandb).create_latents()

        # Hand unoccupied cached GPU memory back once the stages are done.
        torch.cuda.empty_cache()
        return self.run_id
|
|
|
|
if __name__ == '__main__':
    # Number GPUs by PCI bus ID (matching nvidia-smi), then expose only the
    # devices named in the config.
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices

    # NOTE(review): these variables only take effect if CUDA has not been
    # initialized yet; the module-level imports run before this point --
    # confirm none of them touch the GPU.
    runner = ExperimentRunner()
    runner.run_experiment(
        run_pt=True,
        create_other_latents=False,
        use_multi_id_training=False,
    )
|
|