| from __future__ import print_function |
| import torch.backends.cudnn as cudnn |
| import torch |
| import torchvision.transforms as transforms |
|
|
| import argparse |
| import os |
| import random |
| import sys |
| import pprint |
| import datetime |
| import dateutil |
| import dateutil.tz |
|
|
|
|
# Put this script's own directory on sys.path so the project-local imports
# that follow (``miscc.*``, ``trainer``) can be resolved from it.
# NOTE: the previous ``abspath(join(realpath(__file__), './.'))`` normalised
# back to the script *file* path, so a file (not its folder) was appended.
dir_path = os.path.dirname(os.path.realpath(__file__))
if dir_path not in sys.path:
    sys.path.append(dir_path)
|
|
| from miscc.datasets import TextDataset |
| from miscc.config import cfg, cfg_from_file |
| from miscc.utils import mkdir_p |
| from trainer import GANTrainer |
|
|
|
|
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes:
          - ``cfg_file`` (str, default ``'birds_stage1.yml'``): YAML config path.
          - ``gpu_id`` (str, default ``'0'``): value for ``--gpu``.
          - ``data_dir`` (str, default ``''``): value for ``--data_dir``.
          - ``manualSeed`` (int or None): value for ``--manualSeed``.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    cli = argparse.ArgumentParser(description='Train a GAN network')
    # (flag, keyword options) pairs, registered in this order so the
    # generated --help text lists the options in the same sequence.
    option_table = (
        ('--cfg', {'dest': 'cfg_file', 'help': 'optional config file',
                   'default': 'birds_stage1.yml', 'type': str}),
        ('--gpu', {'dest': 'gpu_id', 'type': str, 'default': '0'}),
        ('--data_dir', {'dest': 'data_dir', 'type': str, 'default': ''}),
        ('--manualSeed', {'type': int, 'help': 'manual seed'}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
def main():
    """Parse CLI options, configure, seed RNGs, then train or sample.

    Steps:
      1. Load the YAML config named by ``--cfg`` into the global ``cfg``.
      2. Apply command-line overrides: ``--gpu`` unless it is ``'-1'``, and
         ``--data_dir`` unless it is empty.
      3. Seed ``random`` and ``torch`` (and all CUDA devices when
         ``cfg.CUDA`` is set). If no ``--manualSeed`` was given, draw one in
         [1, 10000] and print it.
      4. Build a timestamped output path under ``../output``. The path is
         relative, so it resolves against the current working directory.
      5. Dispatch on ``cfg.TRAIN.FLAG``:
         - true: ``GANTrainer.train`` on a shuffled ``TextDataset``
           ('train' split) loader;
         - false: ``GANTrainer.sample`` on ``cfg.EVAL_DIR``.

    Raises:
        RuntimeError: if the training dataset evaluates as empty.
    """
    args = parse_args()
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    # ``--gpu`` is parsed as a str, so the "-1 keeps the config's GPU_ID"
    # sentinel must be compared as a string. Comparing against the int -1
    # was always true and made the sentinel dead code.
    if args.gpu_id != '-1':
        cfg.GPU_ID = args.gpu_id
    if args.data_dir != '':
        cfg.DATA_DIR = args.data_dir
    print('Using config:')
    pprint.pprint(cfg)

    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    # Report the seed so that a run with a randomly drawn seed can be reproduced.
    print('Random seed: %d' % args.manualSeed)
    random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(args.manualSeed)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
        (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    if cfg.TRAIN.FLAG:
        # One entry per comma-separated device id. The loader batch is the
        # per-device batch size times this count. str() guards the case
        # where GPU_ID was left at a non-string config value.
        num_gpu = len(str(cfg.GPU_ID).split(','))
        dataset = TextDataset(cfg.DATA_DIR, 'train',
                              rirsize=cfg.RIRSIZE)
        # Explicit check (same truthiness test as the old ``assert``) so it
        # still runs under ``python -O``.
        if not dataset:
            raise RuntimeError(
                'Training dataset is empty or unavailable: %s' % cfg.DATA_DIR)

        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        file_path = cfg.EVAL_DIR
        algo = GANTrainer(output_dir)
        algo.sample(file_path, cfg.STAGE)


if __name__ == "__main__":
    main()
|
|