| import matplotlib |
|
|
| matplotlib.use('Agg') |
|
|
| import os, sys |
| import yaml |
| from argparse import ArgumentParser |
| from time import gmtime, strftime |
| from shutil import copy |
|
|
| from frames_dataset import FramesDataset |
|
|
| from modules.generator import OcclusionAwareGenerator |
| from modules.discriminator import MultiScaleDiscriminator |
| from modules.keypoint_detector import KPDetector |
|
|
| import torch |
|
|
| from train import train |
| from reconstruction import reconstruction |
| from animate import animate |
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, load the YAML config, build
    # the generator / discriminator / keypoint detector and the dataset, then
    # dispatch to train(), reconstruction() or animate() according to --mode.

    if sys.version_info[0] < 3:
        raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")

    parser = ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config")
    parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "animate"])
    parser.add_argument("--log_dir", default='log', help="path to log into")
    parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
    # "0,1,2" -> [0, 1, 2]; argparse also runs the string default through `type`.
    parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))),
                        help="Names of the devices comma separated.")
    parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
    parser.set_defaults(verbose=False)

    opt = parser.parse_args()

    # yaml.load() without an explicit Loader raises TypeError on PyYAML >= 6
    # and can construct arbitrary Python objects on older releases.
    # safe_load() works on every PyYAML version.
    # NOTE(review): assumes configs are plain YAML; if any config relies on
    # Python-specific tags, switch to yaml.load(f, Loader=...) deliberately.
    with open(opt.config) as f:
        config = yaml.safe_load(f)

    if opt.checkpoint is not None:
        # Reuse the directory that holds the checkpoint. A bare file name has
        # an empty dirname, which os.makedirs() rejects, so fall back to the
        # current directory in that case.
        log_dir = os.path.dirname(opt.checkpoint) or '.'
    else:
        # Fresh run: <log_dir>/<config name up to its first dot> plus a UTC timestamp.
        config_name = os.path.basename(opt.config).split('.')[0]
        log_dir = os.path.join(opt.log_dir, config_name)
        log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())

    model_params = config['model_params']
    common_params = model_params['common_params']
    use_cuda = torch.cuda.is_available()
    # Each model is placed on the first listed device here; the full device
    # list is only handed to train() further down.
    main_device = opt.device_ids[0]

    generator = OcclusionAwareGenerator(**model_params['generator_params'],
                                        **common_params)
    if use_cuda:
        generator.to(main_device)
    if opt.verbose:
        print(generator)

    discriminator = MultiScaleDiscriminator(**model_params['discriminator_params'],
                                            **common_params)
    if use_cuda:
        discriminator.to(main_device)
    if opt.verbose:
        print(discriminator)

    kp_detector = KPDetector(**model_params['kp_detector_params'],
                             **common_params)
    if use_cuda:
        kp_detector.to(main_device)
    if opt.verbose:
        print(kp_detector)

    dataset = FramesDataset(is_train=(opt.mode == 'train'), **config['dataset_params'])

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(log_dir, exist_ok=True)
    # Keep a copy of the config next to the logs unless one is already there.
    if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
        copy(opt.config, log_dir)

    if opt.mode == 'train':
        print("Training...")
        train(config, generator, discriminator, kp_detector, opt.checkpoint, log_dir, dataset, opt.device_ids)
    elif opt.mode == 'reconstruction':
        print("Reconstruction...")
        reconstruction(config, generator, kp_detector, opt.checkpoint, log_dir, dataset)
    elif opt.mode == 'animate':
        print("Animate...")
        animate(config, generator, kp_detector, opt.checkpoint, log_dir, dataset)
|
|