| """ Utilities file |
| This file contains utility functions for bookkeeping, logging, and data loading. |
| Methods which directly affect training should either go in layers, the model, |
| or train_fns.py. |
| """ |
|
|
| from __future__ import print_function |
| import sys |
| import os |
| import numpy as np |
| import time |
| import datetime |
| import json |
| import pickle |
| from argparse import ArgumentParser |
| import random |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision |
| import torchvision.transforms as transforms |
|
|
|
|
| def prepare_parser(): |
| usage = "Parser for all scripts." |
| parser = ArgumentParser(description=usage) |
|
|
    parser.add_argument(
        "--json_config",
        type=str,
        default="",
        help="JSON config file from which to load configuration parameters.",
    )
|
|
| |
| parser.add_argument( |
| "--resolution", |
| type=int, |
| default=64, |
| help="Resolution to train with " "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--augment", |
| action="store_true", |
| default=False, |
| help="Augment with random crops and flips (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--num_workers",
        type=int,
        default=8,
        help="Number of dataloader workers; consider using fewer for HDF5 "
        "(default: %(default)s)",
    )
| parser.add_argument( |
| "--no_pin_memory", |
| action="store_false", |
| dest="pin_memory", |
| default=True, |
| help="Pin data into memory through dataloader? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--shuffle", |
| action="store_true", |
| default=False, |
| help="Shuffle the data (strongly recommended)? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--load_in_mem", |
| action="store_true", |
| default=False, |
| help="Load all data into memory? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--use_multiepoch_sampler", |
| action="store_true", |
| default=False, |
| help="Use the multi-epoch sampler for dataloader? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--use_checkpointable_sampler", |
| action="store_true", |
| default=False, |
| help="Use the checkpointable sampler for dataloader? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--use_balanced_sampler", |
| action="store_true", |
| default=False, |
| help="Use the class balanced sampler for dataloader? (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--longtail_temperature",
        type=float,
        default=1.0,
        help="Temperature used to relax the long-tail class distribution "
        "(default: %(default)s)",
    )
|
|
| parser.add_argument( |
| "--longtail", |
| action="store_true", |
| default=False, |
| help="Use long-tail version of the dataset", |
| ) |
| parser.add_argument( |
| "--longtail_gen", |
| action="store_true", |
| default=False, |
| help="Use long-tail version of class conditioning sampling for generator.", |
| ) |
| parser.add_argument( |
| "--custom_distrib_gen", |
| action="store_true", |
| default=False, |
| help="Use custom distribution for sampling class conditionings in generator.", |
| ) |
|
|
| |
| parser.add_argument( |
| "--DiffAugment", type=str, default="", help="DiffAugment policy" |
| ) |
| parser.add_argument( |
| "--DA", |
| action="store_true", |
| default=False, |
| help="Diff Augment for GANs (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--hflips",
        action="store_true",
        default=False,
        help="Use horizontal flips in data augmentation (default: %(default)s)",
    )
|
|
| |
| parser.add_argument( |
| "--instance_cond", |
| action="store_true", |
| default=False, |
| help="Use instance features as conditioning", |
| ) |
| parser.add_argument( |
| "--feature_augmentation", |
| action="store_true", |
| default=False, |
| help="use hflips in instance conditionings (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--which_knn_balance", |
| type=str, |
| default="instance_balance", |
| choices=["instance_balance", "nnclass_balance"], |
| help="Class balancing either done at the instance level or at the class level.", |
| ) |
| parser.add_argument( |
| "--G_shared_feat", |
| action="store_true", |
| default=False, |
| help="Use fully connected layer for conditioning instance features in G? (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--shared_dim_feat",
        type=int,
        default=2048,
        help="G's fully connected layer output dimensionality for instance "
        "features (default: %(default)s)",
    )
    parser.add_argument(
        "--k_nn",
        type=int,
        default=50,
        help="Number of neighbors for each instance (default: %(default)s)",
    )
| parser.add_argument( |
| "--feature_extractor", |
| type=str, |
| default="classification", |
| choices=["classification", "selfsupervised"], |
| help="Choice of feature extractor", |
| ) |
| parser.add_argument( |
| "--backbone_feature_extractor", |
| type=str, |
| default="resnet50", |
| choices=["resnet50"], |
| help="Choice of feature extractor backbone", |
| ) |
|
|
| parser.add_argument( |
| "--eval_instance_set", |
| type=str, |
| default="train", |
| help="(Eval) Dataset split from which to draw conditioning instances (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--kmeans_subsampled", |
| type=int, |
| default=-1, |
| help="Number of kmeans centers if using subsampled training instances (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--n_subsampled_data", |
| type=float, |
| default=-1, |
| help="Percent of instances used at test time", |
| ) |
|
|
| |
    parser.add_argument(
        "--filter_hd",
        type=int,
        default=-1,
        help="Hamming distance used to filter the COCO_Stuff val set "
        "(-1: no filtering) (default: %(default)s)",
    )
|
|
| |
| parser.add_argument( |
| "--model", |
| type=str, |
| default="BigGAN", |
| help="Name of the model module (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--G_param", |
| type=str, |
| default="SN", |
| help="Parameterization style to use for G, spectral norm (SN) or SVD (SVD)" |
| " or None (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--D_param", |
| type=str, |
| default="SN", |
| help="Parameterization style to use for D, spectral norm (SN) or SVD (SVD)" |
| " or None (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--G_ch", |
| type=int, |
| default=64, |
| help="Channel multiplier for G (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--D_ch", |
| type=int, |
| default=64, |
| help="Channel multiplier for D (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--G_depth", |
| type=int, |
| default=1, |
| help="Number of resblocks per stage in G? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--D_depth", |
| type=int, |
| default=1, |
| help="Number of resblocks per stage in D? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--D_thin", |
| action="store_false", |
| dest="D_wide", |
| default=True, |
| help="Use the SN-GAN channel pattern for D? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--G_shared", |
| action="store_true", |
| default=True, |
| help="Use shared embeddings in G? (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--shared_dim",
        type=int,
        default=0,
        help="G's shared embedding dimensionality; if 0, will be equal to dim_z "
        "(default: %(default)s)",
    )
    parser.add_argument(
        "--dim_z", type=int, default=120, help="Noise dimensionality (default: %(default)s)"
    )
    parser.add_argument(
        "--z_var", type=float, default=1.0, help="Noise variance (default: %(default)s)"
    )
| parser.add_argument( |
| "--hier", |
| action="store_true", |
| default=False, |
| help="Use hierarchical z in G? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--syncbn", |
| action="store_true", |
| default=False, |
| help="Sync batch norm? (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--cross_replica",
        action="store_true",
        default=False,
        help="Cross_replica batchnorm in G? (default: %(default)s)",
    )
    parser.add_argument(
        "--mybn",
        action="store_true",
        default=False,
        help="Use custom batchnorm (which supports standing stats) "
        "(default: %(default)s)",
    )
| parser.add_argument( |
| "--G_nl", |
| type=str, |
| default="relu", |
| help="Activation function for G (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--D_nl", |
| type=str, |
| default="relu", |
| help="Activation function for D (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--G_attn", |
| type=str, |
| default="64", |
| help="What resolutions to use attention on for G (underscore separated) " |
| "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--D_attn", |
| type=str, |
| default="64", |
| help="What resolutions to use attention on for D (underscore separated) " |
| "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--norm_style", |
| type=str, |
| default="bn", |
| help="Normalizer style for G, one of bn [batchnorm], in [instancenorm], " |
| "ln [layernorm], gn [groupnorm] (default: %(default)s)", |
| ) |
|
|
| |
| parser.add_argument( |
| "--seed", |
| type=int, |
| default=0, |
| help="Random seed to use; affects both initialization and " |
| " dataloading. (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--G_init", |
| type=str, |
| default="ortho", |
| help="Init style to use for G (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--D_init",
        type=str,
        default="ortho",
        help="Init style to use for D (default: %(default)s)",
    )
| parser.add_argument( |
| "--skip_init", |
| action="store_true", |
| default=False, |
| help="Skip initialization, ideal for testing when ortho init was used " |
| "(default: %(default)s)", |
| ) |
|
|
| |
| parser.add_argument( |
| "--G_lr", |
| type=float, |
| default=5e-5, |
| help="Learning rate to use for Generator (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--D_lr", |
| type=float, |
| default=2e-4, |
| help="Learning rate to use for Discriminator (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--G_B1", |
| type=float, |
| default=0.0, |
| help="Beta1 to use for Generator (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--D_B1", |
| type=float, |
| default=0.0, |
| help="Beta1 to use for Discriminator (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--G_B2", |
| type=float, |
| default=0.999, |
| help="Beta2 to use for Generator (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--D_B2", |
| type=float, |
| default=0.999, |
| help="Beta2 to use for Discriminator (default: %(default)s)", |
| ) |
|
|
| |
| parser.add_argument( |
| "--batch_size", |
| type=int, |
| default=64, |
| help="Default overall batchsize (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--G_batch_size", |
| type=int, |
| default=0, |
| help="Batch size to use for G; if 0, same as D (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--num_G_accumulations",
        type=int,
        default=1,
        help="Number of passes to accumulate G's gradients over "
        "(default: %(default)s)",
    )
| parser.add_argument( |
| "--num_D_steps", |
| type=int, |
| default=2, |
| help="Number of D steps per G step (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--num_D_accumulations",
        type=int,
        default=1,
        help="Number of passes to accumulate D's gradients over "
        "(default: %(default)s)",
    )
| parser.add_argument( |
| "--split_D", |
| action="store_true", |
| default=False, |
| help="Run D twice rather than concatenating inputs? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--num_epochs", |
| type=int, |
| default=100, |
| help="Number of epochs to train for (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--parallel", |
| action="store_true", |
| default=False, |
| help="Train with multiple GPUs (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--G_fp16", |
| action="store_true", |
| default=False, |
| help="Train with half-precision in G? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--D_fp16", |
| action="store_true", |
| default=False, |
| help="Train with half-precision in D? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--D_mixed_precision", |
| action="store_true", |
| default=False, |
| help="Train with half-precision activations but fp32 params in D? " |
| "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--G_mixed_precision", |
| action="store_true", |
| default=False, |
| help="Train with half-precision activations but fp32 params in G? " |
| "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--accumulate_stats", |
| action="store_true", |
| default=False, |
| help='Accumulate "standing" batchnorm stats? (default: %(default)s)', |
| ) |
| parser.add_argument( |
| "--num_standing_accumulations", |
| type=int, |
| default=16, |
| help="Number of forward passes to use in accumulating standing stats? " |
| "(default: %(default)s)", |
| ) |
|
|
| |
| parser.add_argument( |
| "--slurm_logdir", |
| help="Where to save the logs from SLURM", |
| required=False, |
| default="biggan-training-runs", |
| metavar="DIR", |
| ) |
|
|
| parser.add_argument( |
| "--G_eval_mode", |
| action="store_true", |
| default=False, |
| help="Run G in eval mode (running/standing stats?) at sample/test time? " |
| "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--save_every", |
| type=int, |
| default=2000, |
| help="Save every X iterations (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--num_save_copies", |
| type=int, |
| default=2, |
| help="How many copies to save (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--num_best_copies", |
| type=int, |
| default=2, |
| help="How many previous best checkpoints to save (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--which_best",
        type=str,
        default="IS",
        help='Which metric to use to determine when to save new "best" '
        "checkpoints, one of IS or FID (default: %(default)s)",
    )
| parser.add_argument( |
| "--no_fid", |
| action="store_true", |
| default=False, |
| help="Calculate IS only, not FID? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--test_every", |
| type=int, |
| default=5000, |
| help="Test every X iterations (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--num_inception_images", |
| type=int, |
| default=50000, |
| help="Number of samples to compute inception metrics with " |
| "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--hashname", |
| action="store_true", |
| default=False, |
| help="Use a hash of the experiment name instead of the full config " |
| "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--base_root", |
| type=str, |
| default="", |
| help="Default location to store all weights, samples, data, and logs " |
| " (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--data_root", |
| type=str, |
| default="data", |
| help="Default location where data is stored (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--weights_root", |
| type=str, |
| default="weights", |
| help="Default location to store weights (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--logs_root", |
| type=str, |
| default="logs", |
| help="Default location to store logs (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--samples_root", |
| type=str, |
| default="samples", |
| help="Default location to store samples (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--pbar", |
| type=str, |
| default="mine", |
| help='Type of progressbar to use; one of "mine" or "tqdm" ' |
| "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--name_suffix", |
| type=str, |
| default="", |
| help="Suffix for experiment name for loading weights for sampling " |
| '(consider "best0") (default: %(default)s)', |
| ) |
| parser.add_argument( |
| "--experiment_name", |
| type=str, |
| default="", |
| help="Optionally override the automatic experiment naming with this arg. " |
| "(default: %(default)s)", |
| ) |
    parser.add_argument(
        "--config_from_name",
        action="store_true",
        default=False,
        help="Reconstruct the config from the experiment name "
        "(default: %(default)s)",
    )
|
|
| |
    parser.add_argument(
        "--ema",
        action="store_true",
        default=False,
        help="Keep an EMA of G's weights? (default: %(default)s)",
    )
| parser.add_argument( |
| "--ema_decay", |
| type=float, |
| default=0.9999, |
| help="EMA decay rate (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--use_ema", |
| action="store_true", |
| default=False, |
| help="Use the EMA parameters of G for evaluation? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--ema_start", |
| type=int, |
| default=20000, |
| help="When to start updating the EMA weights (default: %(default)s)", |
| ) |
|
|
| |
| parser.add_argument( |
| "--adam_eps", |
| type=float, |
| default=1e-6, |
| help="epsilon value to use for Adam (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--BN_eps", |
| type=float, |
| default=1e-5, |
| help="epsilon value to use for BatchNorm (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--SN_eps",
        type=float,
        default=1e-6,
        help="epsilon value to use for Spectral Norm (default: %(default)s)",
    )
| parser.add_argument( |
| "--num_G_SVs", |
| type=int, |
| default=1, |
| help="Number of SVs to track in G (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--num_D_SVs", |
| type=int, |
| default=1, |
| help="Number of SVs to track in D (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--num_G_SV_itrs", |
| type=int, |
| default=1, |
| help="Number of SV itrs in G (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--num_D_SV_itrs", |
| type=int, |
| default=1, |
| help="Number of SV itrs in D (default: %(default)s)", |
| ) |
|
|
| parser.add_argument( |
| "--class_cond", |
| action="store_true", |
| default=False, |
| help="Use classes as conditioning", |
| ) |
    parser.add_argument(
        "--constant_conditioning",
        action="store_true",
        default=False,
        help="Use a class-conditioning vector where the input label is always 0? "
        "(default: %(default)s)",
    )
|
|
    parser.add_argument(
        "--which_dataset",
        type=str,
        default="imagenet",
        help="Dataset choice.",
    )
|
|
| |
    parser.add_argument(
        "--G_ortho",
        type=float,
        default=0.0,
        help="Modified ortho reg coefficient in G (default: %(default)s)",
    )
| parser.add_argument( |
| "--D_ortho", |
| type=float, |
| default=0.0, |
| help="Modified ortho reg coefficient in D (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--toggle_grads",
        action="store_true",
        default=True,
        help="Toggle D and G's \"requires_grad\" settings when not training them? "
        "(default: %(default)s)",
    )
|
|
| |
| parser.add_argument( |
| "--partition", |
| help="Partition name for SLURM", |
| required=False, |
| default="learnlab", |
| ) |
    parser.add_argument(
        "--which_train_fn",
        type=str,
        default="GAN",
        help="Which train function to use (default: %(default)s)",
    )
| parser.add_argument( |
| "--run_setup", |
| type=str, |
| default="slurm", |
| help="If local_debug or slurm (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--ddp_train",
        action="store_true",
        default=False,
        help="Use DDP for training? (default: %(default)s)",
    )
| parser.add_argument( |
| "--n_nodes", |
| type=int, |
| default=1, |
| help="Number of nodes for ddp (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--n_gpus_per_node", |
| type=int, |
| default=1, |
| help="Number of gpus per node for ddp (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--stop_when_diverge",
        action="store_true",
        default=False,
        help="Stop the experiment if there are signs of divergence "
        "(default: %(default)s)",
    )
| parser.add_argument( |
| "--es_patience", type=int, default=50, help="Epochs for early stopping patience" |
| ) |
    parser.add_argument(
        "--deterministic_run",
        action="store_true",
        default=False,
        help="Set deterministic cuDNN behavior and re-seed at each epoch "
        "(default: %(default)s)",
    )
|
|
| |
    parser.add_argument(
        "--eval_prdc",
        action="store_true",
        default=False,
        help="(Eval) Evaluate PRDC (precision, recall, density, coverage) "
        "(default: %(default)s)",
    )
| parser.add_argument( |
| "--eval_reference_set", |
| type=str, |
| default="train", |
| help="(Eval) Reference dataset to use for FID computation (default: %(default)s)", |
| ) |
|
|
| |
| parser.add_argument( |
| "--load_weights", |
| type=str, |
| default="", |
| help="Suffix for which weights to load (e.g. best0, copy0) " |
| "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--resume", |
| action="store_true", |
| default=False, |
| help="Resume training? (default: %(default)s)", |
| ) |
|
|
| |
    parser.add_argument(
        "--logstyle",
        type=str,
        default="%3.3e",
        help="What style to use when logging training metrics? "
        "One of: %%#.#f / %%#.#e (float/exp, text), "
        "pickle (python pickle), "
        "npz (numpy zip), "
        "mat (MATLAB .mat file) (default: %(default)s)",
    )
| parser.add_argument( |
| "--log_G_spectra", |
| action="store_true", |
| default=False, |
| help="Log the top 3 singular values in each SN layer in G? " |
| "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--log_D_spectra", |
| action="store_true", |
| default=False, |
| help="Log the top 3 singular values in each SN layer in D? " |
| "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--sv_log_interval", |
| type=int, |
| default=10, |
| help="Iteration interval for logging singular values " |
| " (default: %(default)s)", |
| ) |
|
|
| return parser |
|
|
|
|
| |
| def add_sample_parser(parser): |
| parser.add_argument( |
| "--sample_npz", |
| action="store_true", |
| default=False, |
| help='Sample "sample_num_npz" images and save to npz? ' |
| "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--sample_num_npz", |
| type=int, |
| default=50000, |
| help="Number of images to sample when sampling NPZs " "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--sample_sheets", |
| action="store_true", |
| default=False, |
| help="Produce class-conditional sample sheets and stick them in " |
| "the samples root? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--sample_interps", |
| action="store_true", |
| default=False, |
| help="Produce interpolation sheets and stick them in " |
| "the samples root? (default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--sample_sheet_folder_num", |
| type=int, |
| default=-1, |
| help="Number to use for the folder for these sample sheets " |
| "(default: %(default)s)", |
| ) |
| parser.add_argument( |
| "--sample_random", |
| action="store_true", |
| default=False, |
| help="Produce a single random sheet? (default: %(default)s)", |
| ) |
    parser.add_argument(
        "--sample_trunc_curves",
        type=str,
        default="",
        help="Get inception metrics with a range of variances? "
        "To use this, specify a startpoint, step, and endpoint, e.g. "
        "--sample_trunc_curves 0.2_0.1_1.0 for a startpoint of 0.2, "
        "stepsize of 0.1, and endpoint of 1.0. Note that this is "
        "not exactly identical to using tf.truncated_normal, but should "
        "have approximately the same effect. (default: %(default)s)",
    )
| parser.add_argument( |
| "--sample_inception_metrics", |
| action="store_true", |
| default=False, |
| help="Calculate Inception metrics with sample.py? (default: %(default)s)", |
| ) |
| return parser |
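

# Example (sketch): how these parsers are typically consumed; argv=None falls
# through to sys.argv. The rest of this codebase passes the configuration
# around as a plain dict.
def example_parse_config(argv=None):
    parser = prepare_parser()
    parser = add_sample_parser(parser)
    config = vars(parser.parse_args(argv))
    return config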
|
|
|
|
| activation_dict = { |
| "inplace_relu": nn.ReLU(inplace=True), |
| "relu": nn.ReLU(inplace=False), |
| "ir": nn.ReLU(inplace=True), |
| } |
|
|
|
|
class CenterCropLongEdge(object):
    """Center-crops the given PIL Image along its long edge.

    The output is a square crop of side min(width, height); no size argument
    is needed.
    """
|
|
| def __call__(self, img): |
| """ |
| Args: |
| img (PIL Image): Image to be cropped. |
| Returns: |
| PIL Image: Cropped image. |
| """ |
| return transforms.functional.center_crop(img, min(img.size)) |
|
|
| def __repr__(self): |
| return self.__class__.__name__ |
|
|
|
|
class RandomCropLongEdge(object):
    """Crops the given PIL Image along its long edge at a random offset.

    The output is a square crop of side min(width, height); no size argument
    is needed.
    """
|
|
| def __call__(self, img): |
| """ |
| Args: |
| img (PIL Image): Image to be cropped. |
| Returns: |
| PIL Image: Cropped image. |
| """ |
        size = (min(img.size), min(img.size))
        # transforms.functional.crop expects (top, left), while PIL's img.size
        # is (width, height); randint's high is exclusive, hence the +1.
        top = (
            0
            if size[1] == img.size[1]
            else np.random.randint(low=0, high=img.size[1] - size[1] + 1)
        )
        left = (
            0
            if size[0] == img.size[0]
            else np.random.randint(low=0, high=img.size[0] - size[0] + 1)
        )
        return transforms.functional.crop(img, top, left, size[1], size[0])
|
|
| def __repr__(self): |
| return self.__class__.__name__ |
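

# Example (sketch): composing the crops above into a transform pipeline, in the
# spirit of how --augment/--hflips are consumed elsewhere. The 0.5/0.5
# normalization constants are an assumption, not values fixed by this file.
def example_transforms(resolution, augment=False):
    crop = RandomCropLongEdge() if augment else CenterCropLongEdge()
    ops = [crop, transforms.Resize(resolution)]
    if augment:
        ops += [transforms.RandomHorizontalFlip()]
    ops += [
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(ops)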
|
|
|
|
| |
def seed_rng(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
|
|
|
|
def seed_worker(worker_id):
    worker_seed = (torch.initial_seed() + worker_id) % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
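

# Example (sketch): wiring seed_worker into a DataLoader so numpy/python RNG
# state differs per worker; `dataset` is any torch Dataset, and the config
# keys mirror the parser defaults above.
def example_dataloader(dataset, config):
    from torch.utils.data import DataLoader  # local import; not used elsewhere here

    return DataLoader(
        dataset,
        batch_size=config["batch_size"],
        shuffle=config["shuffle"],
        num_workers=config["num_workers"],
        pin_memory=config["pin_memory"],
        worker_init_fn=seed_worker,
    )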
|
|
|
|
| |
| |
| def update_config_roots(config, change_weight_folder=True): |
| if config["base_root"]: |
| print("Pegging all root folders to base root %s" % config["base_root"]) |
| for key in ["weights", "logs", "samples"]: |
| if change_weight_folder: |
| config["%s_root" % key] = "%s/%s" % (config["base_root"], key) |
| else: |
| config["%s_root" % key] = "%s" % (config["base_root"]) |
| return config |
|
|
|
|
| |
def prepare_root(config):
    for key in ["weights_root", "logs_root", "samples_root"]:
        if not os.path.exists(config[key]):
            print("Making directory %s for %s..." % (config[key], key))
            os.makedirs(config[key])
|
|
|
|
| |
| |
| |
| class ema(object): |
| def __init__(self, source, target, decay=0.9999, start_itr=0): |
| self.source = source |
| self.target = target |
| self.decay = decay |
| |
| self.start_itr = start_itr |
| |
| self.source_dict = self.source.state_dict() |
| self.target_dict = self.target.state_dict() |
| print("Initializing EMA parameters to be source parameters...") |
| with torch.no_grad(): |
| for key in self.source_dict: |
| self.target_dict[key].data.copy_(self.source_dict[key].data) |
| |
|
|
| def update(self, itr=None): |
| |
| |
        # Before start_itr, keep the EMA pinned to the source weights
        if itr is not None and itr < self.start_itr:
            decay = 0.0
        else:
            decay = self.decay
| with torch.no_grad(): |
| for key in self.source_dict: |
| self.target_dict[key].data.copy_( |
| self.target_dict[key].data * decay |
| + self.source_dict[key].data * (1 - decay) |
| ) |
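

# Example (sketch): typical EMA wiring, assuming G and G_ema are two instances
# of the same generator class; update() runs once per G step.
def example_ema_usage(G, G_ema, config, itr):
    ema_tracker = ema(G, G_ema, decay=config["ema_decay"], start_itr=config["ema_start"])
    ema_tracker.update(itr)
    return ema_tracker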
|
|
|
|
| |
| |
| |
| def ortho(model, strength=1e-4, blacklist=[]): |
| with torch.no_grad(): |
| for param in model.parameters(): |
| |
| if len(param.shape) < 2 or any([param is item for item in blacklist]): |
| continue |
| w = param.view(param.shape[0], -1) |
| grad = 2 * torch.mm( |
| torch.mm(w, w.t()) * (1.0 - torch.eye(w.shape[0], device=w.device)), w |
| ) |
| param.grad.data += strength * grad.view(param.shape) |
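

# Example (sketch): applying modified ortho reg to G between loss.backward()
# and the optimizer step, blacklisting the shared embedding (assumes G was
# built with --G_shared), mirroring BigGAN-style train loops.
def example_apply_ortho(G, config):
    if config["G_ortho"] > 0.0:
        ortho(
            G,
            config["G_ortho"],
            blacklist=[param for param in G.shared.parameters()],
        )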
|
|
|
|
| |
| |
| |
| def default_ortho(model, strength=1e-4, blacklist=[]): |
| with torch.no_grad(): |
| for param in model.parameters(): |
| |
| if len(param.shape) < 2 or param in blacklist: |
| continue |
| w = param.view(param.shape[0], -1) |
| grad = 2 * torch.mm( |
| torch.mm(w, w.t()) - torch.eye(w.shape[0], device=w.device), w |
| ) |
| param.grad.data += strength * grad.view(param.shape) |
|
|
|
|
| |
| def toggle_grad(model, on_or_off): |
| for param in model.parameters(): |
| param.requires_grad = on_or_off |
|
|
|
|
| |
| |
| |
| def join_strings(base_string, strings): |
| return base_string.join([item for item in strings if item]) |
|
|
|
|
| |
| def save_weights( |
| G, |
| D, |
| state_dict, |
| weights_root, |
| experiment_name, |
| name_suffix=None, |
| G_ema=None, |
| embedded_optimizers=True, |
| G_optim=None, |
| D_optim=None, |
| ): |
| root = "/".join([weights_root, experiment_name]) |
| if not os.path.exists(root): |
| os.mkdir(root) |
| if name_suffix: |
| print("Saving weights to %s/%s..." % (root, name_suffix)) |
| else: |
| print("Saving weights to %s..." % root) |
| torch.save( |
| G.state_dict(), "%s/%s.pth" % (root, join_strings("_", ["G", name_suffix])) |
| ) |
| torch.save( |
| D.state_dict(), "%s/%s.pth" % (root, join_strings("_", ["D", name_suffix])) |
| ) |
| torch.save( |
| state_dict, "%s/%s.pth" % (root, join_strings("_", ["state_dict", name_suffix])) |
| ) |
|
|
| if embedded_optimizers: |
| torch.save( |
| G.optim.state_dict(), |
| "%s/%s.pth" % (root, join_strings("_", ["G_optim", name_suffix])), |
| ) |
| torch.save( |
| D.optim.state_dict(), |
| "%s/%s.pth" % (root, join_strings("_", ["D_optim", name_suffix])), |
| ) |
| else: |
| torch.save( |
| G_optim.state_dict(), |
| "%s/%s.pth" % (root, join_strings("_", ["G_optim", name_suffix])), |
| ) |
| torch.save( |
| D_optim.state_dict(), |
| "%s/%s.pth" % (root, join_strings("_", ["D_optim", name_suffix])), |
| ) |
| if G_ema is not None: |
| torch.save( |
| G_ema.state_dict(), |
| "%s/%s.pth" % (root, join_strings("_", ["G_ema", name_suffix])), |
| ) |
|
|
|
|
| |
| def load_weights( |
| G, |
| D, |
| state_dict, |
| weights_root, |
| experiment_name, |
| name_suffix=None, |
| G_ema=None, |
| strict=True, |
| load_optim=True, |
| eval=False, |
| map_location=None, |
| embedded_optimizers=True, |
| G_optim=None, |
| D_optim=None, |
| ): |
| root = "/".join([weights_root, experiment_name]) |
| if not os.path.exists(root): |
| print("Not loading data, experiment folder does not exist yet!") |
| print(root) |
        if eval:
            raise ValueError("Make sure the weights folder %s exists" % root)
| return |
|
|
| if name_suffix: |
| print("Loading %s weights from %s..." % (name_suffix, root)) |
| else: |
| print("Loading weights from %s..." % root) |
| if G is not None: |
| G.load_state_dict( |
| torch.load( |
| "%s/%s.pth" % (root, join_strings("_", ["G", name_suffix])), |
| map_location=map_location, |
| ), |
| strict=strict, |
| ) |
| if load_optim: |
| if embedded_optimizers: |
| G.optim.load_state_dict( |
| torch.load( |
| "%s/%s.pth" |
| % (root, join_strings("_", ["G_optim", name_suffix])), |
| map_location=map_location, |
| ) |
| ) |
| else: |
| G_optim.load_state_dict( |
| torch.load( |
| "%s/%s.pth" |
| % (root, join_strings("_", ["G_optim", name_suffix])), |
| map_location=map_location, |
| ) |
| ) |
| if D is not None: |
| D.load_state_dict( |
| torch.load( |
| "%s/%s.pth" % (root, join_strings("_", ["D", name_suffix])), |
| map_location=map_location, |
| ), |
| strict=strict, |
| ) |
| if load_optim: |
| if embedded_optimizers: |
| D.optim.load_state_dict( |
| torch.load( |
| "%s/%s.pth" |
| % (root, join_strings("_", ["D_optim", name_suffix])), |
| map_location=map_location, |
| ) |
| ) |
| else: |
| D_optim.load_state_dict( |
| torch.load( |
| "%s/%s.pth" |
| % (root, join_strings("_", ["D_optim", name_suffix])), |
| map_location=map_location, |
| ) |
| ) |
| |
    # Load the training-state values one by one so that a missing file or key
    # doesn't abort the whole load.
    for item in state_dict:
        try:
            state_dict[item] = torch.load(
                "%s/%s.pth" % (root, join_strings("_", ["state_dict", name_suffix])),
                map_location=map_location,
            )[item]
        except (IOError, KeyError):
            print("No saved value for %s, keeping the current one" % item)
| if G_ema is not None: |
| G_ema.load_state_dict( |
| torch.load( |
| "%s/%s.pth" % (root, join_strings("_", ["G_ema", name_suffix])), |
| map_location=map_location, |
| ), |
| strict=strict, |
| ) |
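

# Example (sketch): a checkpoint round trip with the helpers above; the
# "copy0" suffix follows the rotating-copy naming implied by --num_save_copies.
def example_checkpoint_roundtrip(G, D, state_dict, config, experiment_name):
    save_weights(
        G, D, state_dict, config["weights_root"], experiment_name, name_suffix="copy0"
    )
    load_weights(
        G, D, state_dict, config["weights_root"], experiment_name, name_suffix="copy0"
    )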
|
|
|
|
| """ MetricsLogger originally stolen from VoxNet source code. |
| Used for logging inception metrics""" |
|
|
|
|
| class MetricsLogger(object): |
| def __init__(self, fname, reinitialize=False): |
| self.fname = fname |
| self.reinitialize = reinitialize |
| if os.path.exists(self.fname): |
| if self.reinitialize: |
| print("{} exists, deleting...".format(self.fname)) |
| os.remove(self.fname) |
|
|
| def log(self, record=None, **kwargs): |
| """ |
| Assumption: no newlines in the input. |
| """ |
| if record is None: |
| record = {} |
| record.update(kwargs) |
| record["_stamp"] = time.time() |
| with open(self.fname, "a") as f: |
| f.write(json.dumps(record, ensure_ascii=True) + "\n") |
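

# Example (sketch): logging eval metrics as JSON lines; the filename here is
# illustrative, not one this codebase mandates.
def example_metrics_logging(logs_root, experiment_name, itr, IS_mean, FID):
    test_log = MetricsLogger(
        "%s/%s_log.jsonl" % (logs_root, experiment_name), reinitialize=False
    )
    test_log.log(itr=int(itr), IS_mean=float(IS_mean), FID=float(FID))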
|
|
|
|
| |
| |
| |
| |
| |
| |
| class MyLogger(object): |
| def __init__(self, fname, reinitialize=False, logstyle="%3.3f"): |
| self.root = fname |
| if not os.path.exists(self.root): |
| os.mkdir(self.root) |
| self.reinitialize = reinitialize |
| self.metrics = [] |
| self.logstyle = logstyle |
|
|
| |
    def reinit(self, item):
        if os.path.exists("%s/%s.log" % (self.root, item)):
            if self.reinitialize:
                # Only note the deletion of singular-value logs once
                if "sv" in item:
                    if not any("sv" in metric for metric in self.metrics):
                        print("Deleting singular value logs...")
                else:
                    print(
                        "{} exists, deleting...".format(
                            "%s/%s.log" % (self.root, item)
                        )
                    )
                os.remove("%s/%s.log" % (self.root, item))
|
|
| |
| def log(self, itr, **kwargs): |
| for arg in kwargs: |
| if arg not in self.metrics: |
| if self.reinitialize: |
| self.reinit(arg) |
| self.metrics += [arg] |
| if self.logstyle == "pickle": |
| print("Pickle not currently supported...") |
| |
| |
| elif self.logstyle == "mat": |
| print(".mat logstyle not currently supported...") |
| else: |
| with open("%s/%s.log" % (self.root, arg), "a") as f: |
| f.write("%d: %s\n" % (itr, self.logstyle % kwargs[arg])) |
|
|
|
|
| |
| def write_metadata(logs_root, experiment_name, config, state_dict): |
| with open(("%s/%s/metalog.txt" % (logs_root, experiment_name)), "w") as writefile: |
| writefile.write("datetime: %s\n" % str(datetime.datetime.now())) |
| writefile.write("config: %s\n" % str(config)) |
| writefile.write("state: %s\n" % str(state_dict)) |
|
|
|
|
| """ |
| Very basic progress indicator to wrap an iterable in. |
| |
| Author: Jan Schlüter |
| Andy's adds: time elapsed in addition to ETA, makes it possible to add |
| estimated time to 1k iters instead of estimated time to completion. |
| """ |
|
|
|
|
| def progress(items, desc="", total=None, min_delay=0.1, displaytype="s1k"): |
| """ |
| Returns a generator over `items`, printing the number and percentage of |
| items processed and the estimated remaining processing time before yielding |
| the next item. `total` gives the total number of items (required if `items` |
| has no length), and `min_delay` gives the minimum time in seconds between |
| subsequent prints. `desc` gives an optional prefix text (end with a space). |
| """ |
| total = total or len(items) |
| t_start = time.time() |
| t_last = 0 |
| for n, item in enumerate(items): |
| t_now = time.time() |
| if t_now - t_last > min_delay: |
| print( |
| "\r%s%d/%d (%6.2f%%)" % (desc, n + 1, total, n / float(total) * 100), |
| end=" ", |
| ) |
| if n > 0: |
|
|
| if displaytype == "s1k": |
| next_1000 = n + (1000 - n % 1000) |
| t_done = t_now - t_start |
| t_1k = t_done / n * next_1000 |
| outlist = list(divmod(t_done, 60)) + list(divmod(t_1k - t_done, 60)) |
| print("(TE/ET1k: %d:%02d / %d:%02d)" % tuple(outlist), end=" ") |
| else: |
| t_done = t_now - t_start |
| t_total = t_done / n * total |
| outlist = list(divmod(t_done, 60)) + list( |
| divmod(t_total - t_done, 60) |
| ) |
| print("(TE/ETA: %d:%02d / %d:%02d)" % tuple(outlist), end=" ") |
|
|
| sys.stdout.flush() |
| t_last = t_now |
| yield item |
| t_total = time.time() - t_start |
| print( |
| "\r%s%d/%d (100.00%%) (took %d:%02d)" |
| % ((desc, total, total) + divmod(t_total, 60)) |
| ) |
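

# Example (sketch): wrapping a dataloader in progress(); any displaytype other
# than "s1k" reports estimated time to completion instead of time to 1k iters.
def example_progress(loader, epoch):
    for batch in progress(loader, desc="Epoch %d: " % epoch, displaytype="eta"):
        pass  # the train step would go here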
|
|
|
|
| |
| def sample( |
| G, |
| sample_conditioning_func, |
| config, |
| class_cond=True, |
| instance_cond=False, |
| device="cuda", |
| ): |
| conditioning = sample_conditioning_func() |
| with torch.no_grad(): |
| if class_cond and not instance_cond: |
| z_, y_ = conditioning |
| y_ = y_.long() |
| y_ = y_.to(device, non_blocking=True) |
| feats_ = None |
| elif instance_cond and not class_cond: |
| z_, feats_ = conditioning |
| feats_ = feats_.to(device, non_blocking=True) |
| y_ = None |
        elif instance_cond and class_cond:
            z_, y_, feats_ = conditioning
            y_, feats_ = (
                y_.to(device, non_blocking=True),
                feats_.to(device, non_blocking=True),
            )
        else:
            # Unconditional case: the conditioning function only returns noise
            z_ = conditioning
            y_, feats_ = None, None
        z_ = z_.to(device, non_blocking=True)
|
|
| if config["parallel"]: |
| G_z = nn.parallel.data_parallel(G, (z_, y_, feats_)) |
| else: |
| G_z = G(z_, y_, feats_) |
| return G_z, y_, feats_ |
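

# Example (sketch): drawing one class-conditional batch via sample(); the
# closure refreshes preallocated z_/y_ tensors, and n_classes is assumed to
# come from the dataset.
def example_sample_batch(G, z_, y_, config, n_classes):
    def conditioning():
        z_.normal_()
        y_.random_(0, n_classes)
        return z_, y_

    return sample(G, conditioning, config, class_cond=True, instance_cond=False)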
|
|
|
|
| |
| def sample_sheet( |
| G, |
| classes_per_sheet, |
| num_classes, |
| samples_per_class, |
| parallel, |
| samples_root, |
| experiment_name, |
| folder_number, |
| z_=None, |
| ): |
| |
| if not os.path.isdir("%s/%s" % (samples_root, experiment_name)): |
| os.mkdir("%s/%s" % (samples_root, experiment_name)) |
| if not os.path.isdir("%s/%s/%d" % (samples_root, experiment_name, folder_number)): |
| os.mkdir("%s/%s/%d" % (samples_root, experiment_name, folder_number)) |
| |
| for i in range(num_classes // classes_per_sheet): |
| ims = [] |
| y = torch.arange( |
| i * classes_per_sheet, (i + 1) * classes_per_sheet, device="cuda" |
| ) |
| for j in range(samples_per_class): |
| if ( |
| (z_ is not None) |
| and hasattr(z_, "sample_") |
| and classes_per_sheet <= z_.size(0) |
| ): |
| z_.sample_() |
| else: |
| z_ = torch.randn(classes_per_sheet, G.dim_z, device="cuda") |
| with torch.no_grad(): |
| if parallel: |
| o = nn.parallel.data_parallel( |
| G, (z_[:classes_per_sheet], G.shared(y)) |
| ) |
| else: |
| o = G(z_[:classes_per_sheet], G.shared(y)) |
|
|
| ims += [o.data.cpu()] |
| |
| out_ims = ( |
| torch.stack(ims, 1) |
| .view(-1, ims[0].shape[1], ims[0].shape[2], ims[0].shape[3]) |
| .data.float() |
| .cpu() |
| ) |
| |
| image_filename = "%s/%s/%d/samples%d.jpg" % ( |
| samples_root, |
| experiment_name, |
| folder_number, |
| i, |
| ) |
| torchvision.utils.save_image( |
| out_ims, image_filename, nrow=samples_per_class, normalize=True |
| ) |
|
|
|
|
| |
| def interp(x0, x1, num_midpoints): |
| lerp = torch.linspace(0, 1.0, num_midpoints + 2, device="cuda").to(x0.dtype) |
| return (x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1)) |
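

# Example (sketch): interp() takes (N, 1, dim_z) endpoints and returns
# (N, num_midpoints + 2, dim_z), endpoints included; the linspace is created
# on CUDA, so the inputs must live there too.
def example_interp(N=4, dim_z=120, num_midpoints=8):
    x0 = torch.randn(N, 1, dim_z, device="cuda")
    x1 = torch.randn(N, 1, dim_z, device="cuda")
    return interp(x0, x1, num_midpoints)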
|
|
|
|
| |
| |
| def interp_sheet( |
| G, |
| num_per_sheet, |
| num_midpoints, |
| num_classes, |
| parallel, |
| samples_root, |
| experiment_name, |
| folder_number, |
| sheet_number=0, |
| fix_z=False, |
| fix_y=False, |
| device="cuda", |
| ): |
| |
| if fix_z: |
| zs = torch.randn(num_per_sheet, 1, G.dim_z, device=device) |
| zs = zs.repeat(1, num_midpoints + 2, 1).view(-1, G.dim_z) |
| else: |
| zs = interp( |
| torch.randn(num_per_sheet, 1, G.dim_z, device=device), |
| torch.randn(num_per_sheet, 1, G.dim_z, device=device), |
| num_midpoints, |
| ).view(-1, G.dim_z) |
| if fix_y: |
| ys = sample_1hot(num_per_sheet, num_classes) |
| ys = G.shared(ys).view(num_per_sheet, 1, -1) |
| ys = ys.repeat(1, num_midpoints + 2, 1).view( |
| num_per_sheet * (num_midpoints + 2), -1 |
| ) |
| else: |
| ys = interp( |
| G.shared(sample_1hot(num_per_sheet, num_classes)).view( |
| num_per_sheet, 1, -1 |
| ), |
| G.shared(sample_1hot(num_per_sheet, num_classes)).view( |
| num_per_sheet, 1, -1 |
| ), |
| num_midpoints, |
| ).view(num_per_sheet * (num_midpoints + 2), -1) |
| |
| if G.fp16: |
| zs = zs.half() |
| with torch.no_grad(): |
| if parallel: |
| out_ims = nn.parallel.data_parallel(G, (zs, ys)).data.cpu() |
| else: |
| out_ims = G(zs, ys).data.cpu() |
| interp_style = "" + ("Z" if not fix_z else "") + ("Y" if not fix_y else "") |
| image_filename = "%s/%s/%d/interp%s%d.jpg" % ( |
| samples_root, |
| experiment_name, |
| folder_number, |
| interp_style, |
| sheet_number, |
| ) |
| torchvision.utils.save_image( |
| out_ims, image_filename, nrow=num_midpoints + 2, normalize=True |
| ) |
|
|
|
|
| |
| |
| def print_grad_norms(net): |
| gradsums = [ |
| [ |
| float(torch.norm(param.grad).item()), |
| float(torch.norm(param).item()), |
| param.shape, |
| ] |
| for param in net.parameters() |
| ] |
| order = np.argsort([item[0] for item in gradsums]) |
| print( |
| [ |
| "%3.3e,%3.3e, %s" |
| % ( |
| gradsums[item_index][0], |
| gradsums[item_index][1], |
| str(gradsums[item_index][2]), |
| ) |
| for item_index in order |
| ] |
| ) |
|
|
|
|
| |
| |
| def get_SVs(net, prefix): |
| d = net.state_dict() |
| return { |
| ("%s_%s" % (prefix, key)).replace(".", "_"): float(d[key].item()) |
| for key in d |
| if "sv" in key |
| } |
|
|
|
|
| |
| def name_from_config(config): |
| name = "_".join( |
| [ |
| item |
| for item in [ |
| "Big%s" % config["which_train_fn"], |
| config["dataset"], |
| config["model"] if config["model"] != "BigGAN" else None, |
| "seed%d" % config["seed"], |
| "Gch%d" % config["G_ch"], |
| "Dch%d" % config["D_ch"], |
| "Gd%d" % config["G_depth"] if config["G_depth"] > 1 else None, |
| "Dd%d" % config["D_depth"] if config["D_depth"] > 1 else None, |
| "bs%d" % config["batch_size"], |
| "Gfp16" if config["G_fp16"] else None, |
| "Dfp16" if config["D_fp16"] else None, |
| "nDs%d" % config["num_D_steps"] if config["num_D_steps"] > 1 else None, |
| "nDa%d" % config["num_D_accumulations"] |
| if config["num_D_accumulations"] > 1 |
| else None, |
| "nGa%d" % config["num_G_accumulations"] |
| if config["num_G_accumulations"] > 1 |
| else None, |
| "Glr%2.1e" % config["G_lr"], |
| "Dlr%2.1e" % config["D_lr"], |
| "GB%3.3f" % config["G_B1"] if config["G_B1"] != 0.0 else None, |
| "GBB%3.3f" % config["G_B2"] if config["G_B2"] != 0.999 else None, |
| "DB%3.3f" % config["D_B1"] if config["D_B1"] != 0.0 else None, |
| "DBB%3.3f" % config["D_B2"] if config["D_B2"] != 0.999 else None, |
| "Gnl%s" % config["G_nl"], |
| "Dnl%s" % config["D_nl"], |
| "Ginit%s" % config["G_init"], |
| "Dinit%s" % config["D_init"], |
| "G%s" % config["G_param"] if config["G_param"] != "SN" else None, |
| "D%s" % config["D_param"] if config["D_param"] != "SN" else None, |
| "Gattn%s" % config["G_attn"] if config["G_attn"] != "0" else None, |
| "Dattn%s" % config["D_attn"] if config["D_attn"] != "0" else None, |
| "Gortho%2.1e" % config["G_ortho"] if config["G_ortho"] > 0.0 else None, |
| "Dortho%2.1e" % config["D_ortho"] if config["D_ortho"] > 0.0 else None, |
| config["norm_style"] if config["norm_style"] != "bn" else None, |
| "cr" if config["cross_replica"] else None, |
| "Gshared" if config["G_shared"] else None, |
| "hier" if config["hier"] else None, |
| "ema" if config["ema"] else None, |
| config["name_suffix"] if config["name_suffix"] else None, |
| ] |
| if item is not None |
| ] |
    )
    return name
|
|
|
|
| |
def query_gpu(indices):
    os.system(
        "nvidia-smi -i %s --query-gpu=memory.free --format=csv"
        % ",".join(str(i) for i in indices)
    )
|
|
|
|
| |
| def count_parameters(module): |
| print( |
| "Number of parameters: {}".format( |
| sum([p.data.nelement() for p in module.parameters()]) |
| ) |
| ) |
|
|
|
|
| |
| def sample_1hot(batch_size, num_classes, device="cuda"): |
| return torch.randint( |
| low=0, |
| high=num_classes, |
| size=(batch_size,), |
| device=device, |
| dtype=torch.int64, |
| requires_grad=False, |
| ) |
|
|
|
|
| def initiate_standing_stats(net): |
| for module in net.modules(): |
| if hasattr(module, "accumulate_standing"): |
| module.reset_stats() |
| module.accumulate_standing = True |
|
|
|
|
| def accumulate_standing_stats(net, z, y, nclasses, num_accumulations=16): |
| initiate_standing_stats(net) |
| net.train() |
| for i in range(num_accumulations): |
| with torch.no_grad(): |
| z.normal_() |
| y.random_(0, nclasses) |
| x = net(z, net.shared(y)) |
| |
| net.eval() |
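

# Example (sketch): accumulating "standing" BN stats before evaluation,
# mirroring --accumulate_stats / --num_standing_accumulations; n_classes is
# assumed to come from the dataset.
def example_standing_stats(G, config, n_classes, device="cuda"):
    z = torch.randn(config["batch_size"], G.dim_z, device=device)
    y = sample_1hot(config["batch_size"], n_classes, device=device)
    accumulate_standing_stats(G, z, y, n_classes, config["num_standing_accumulations"])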
|
|
|
|
| |
| |
| |
| |
| |
| |
| import math |
| from torch.optim.optimizer import Optimizer |
|
|
|
|
| class Adam16(Optimizer): |
| def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): |
| defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) |
| params = list(params) |
| super(Adam16, self).__init__(params, defaults) |
|
|
| |
| def load_state_dict(self, state_dict): |
| super(Adam16, self).load_state_dict(state_dict) |
| for group in self.param_groups: |
| for p in group["params"]: |
| self.state[p]["exp_avg"] = self.state[p]["exp_avg"].float() |
| self.state[p]["exp_avg_sq"] = self.state[p]["exp_avg_sq"].float() |
| self.state[p]["fp32_p"] = self.state[p]["fp32_p"].float() |
|
|
| def step(self, closure=None): |
| """Performs a single optimization step. |
| Arguments: |
| closure (callable, optional): A closure that reevaluates the model |
| and returns the loss. |
| """ |
| loss = None |
| if closure is not None: |
| loss = closure() |
|
|
| for group in self.param_groups: |
| for p in group["params"]: |
| if p.grad is None: |
| continue |
|
|
| grad = p.grad.data.float() |
| state = self.state[p] |
|
|
| |
                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = grad.new().resize_as_(grad).zero_()
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = grad.new().resize_as_(grad).zero_()
                    # FP32 copy of the parameters
                    state["fp32_p"] = p.data.float()
|
|
| exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] |
| beta1, beta2 = group["betas"] |
|
|
| state["step"] += 1 |
|
|
| if group["weight_decay"] != 0: |
| grad = grad.add(group["weight_decay"], state["fp32_p"]) |
|
|
| |
| exp_avg.mul_(beta1).add_(1 - beta1, grad) |
| exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) |
|
|
| denom = exp_avg_sq.sqrt().add_(group["eps"]) |
|
|
| bias_correction1 = 1 - beta1 ** state["step"] |
| bias_correction2 = 1 - beta2 ** state["step"] |
| step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 |
|
|
| state["fp32_p"].addcdiv_(-step_size, exp_avg, denom) |
| p.data = state["fp32_p"].half() |
|
|
| return loss |
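

# Example (sketch): pairing Adam16 with a half-precision G, as when --G_fp16 is
# set; the hyperparameters mirror the parser defaults above.
def example_adam16(G, config):
    return Adam16(
        G.parameters(),
        lr=config["G_lr"],
        betas=(config["G_B1"], config["G_B2"]),
        eps=config["adam_eps"],
    )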
|
|