| from enum import Enum |
| from torch import nn |
|
|
|
|
class TrainMode(Enum):
    """
    Which training procedure is being run.

    The predicate methods group the modes so callers can branch on a
    capability (diffusion-style training, latent network, dataset
    inference) instead of comparing against individual members.
    """

    # NOTE(review): presumably trains the manipulation classifier
    # (cf. ManipulateMode) — confirm against the trainer.
    manipulate = 'manipulate'
    # plain diffusion mode; the only member for which is_autoenc() is true
    diffusion = 'diffusion'
    # diffusion over latents; note the value string has no underscore
    latent_diffusion = 'latentdiffusion'

    def is_manipulate(self):
        return self is TrainMode.manipulate

    def is_diffusion(self):
        # both diffusion flavours count as diffusion training
        return self in (TrainMode.diffusion, TrainMode.latent_diffusion)

    def is_autoenc(self):
        # NOTE(review): the network possibly autoencodes in this mode only
        return self is TrainMode.diffusion

    def is_latent_diffusion(self):
        return self is TrainMode.latent_diffusion

    def use_latent_net(self):
        # a latent network is used exactly when training latent diffusion
        return self.is_latent_diffusion()

    def require_dataset_infer(self):
        """
        Whether training in this mode needs the dataset's latent variables
        to be available before training starts.
        """
        return self in (TrainMode.manipulate, TrainMode.latent_diffusion)
|
|
|
|
class ManipulateMode(Enum):
    """
    How to train the classifier used for manipulation.

    Every member currently counts as a CelebA-attribute mode; the two
    ``d2c_fewshot*`` members are the few-shot, single-class variants.
    """

    # NOTE(review): presumably the full CelebA-HQ attribute set — confirm
    celebahq_all = 'celebahq_all'
    # few-shot variant; value strings drop the underscore
    d2c_fewshot = 'd2cfewshot'
    # few-shot variant for which is_fewshot_allneg() is true
    d2c_fewshot_allneg = 'd2cfewshotallneg'

    def is_celeba_attr(self):
        return self in (
            ManipulateMode.celebahq_all,
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
        )

    def is_single_class(self):
        return self in (ManipulateMode.d2c_fewshot,
                        ManipulateMode.d2c_fewshot_allneg)

    def is_fewshot(self):
        return self in (ManipulateMode.d2c_fewshot,
                        ManipulateMode.d2c_fewshot_allneg)

    def is_fewshot_allneg(self):
        return self is ManipulateMode.d2c_fewshot_allneg
|
|
|
|
class ModelType(Enum):
    """
    Kinds of backbone model.

    Across the current members, ``has_autoenc`` and ``can_sample`` are
    mutually exclusive.
    """

    # plain DDPM backbone; the only kind that reports can_sample()
    ddpm = 'ddpm'
    # backbone with an autoencoder; reports has_autoenc(), not can_sample()
    autoencoder = 'autoencoder'

    def has_autoenc(self):
        return self is ModelType.autoencoder

    def can_sample(self):
        return self is ModelType.ddpm
|
|
|
|
class ModelName(Enum):
    """
    Identifiers of every supported model class.

    For each member the value string equals the member name.
    """

    # DDPM model of the beatgans family
    beatgans_ddpm = 'beatgans_ddpm'
    # autoencoding model of the beatgans family
    beatgans_autoenc = 'beatgans_autoenc'
|
|
|
|
class ModelMeanType(Enum):
    """
    Which kind of quantity the model's output represents.

    Only one option is defined.
    """

    # NOTE(review): presumably the model predicts the added noise (epsilon) — confirm
    eps = 'eps'
|
|
|
|
class ModelVarType(Enum):
    """
    What is used as the model's output variance.

    Only fixed (non-learned) variance choices are defined here; this enum
    has no learned or interpolated variance option.
    """

    # NOTE(review): conventionally the smaller (posterior) variance — confirm
    fixed_small = 'fixed_small'
    # NOTE(review): conventionally the larger (beta_t) variance — confirm
    fixed_large = 'fixed_large'
|
|
|
|
class LossType(Enum):
    """
    Which training loss to use: ``mse`` (mean squared error) or ``l1``.
    """

    mse = 'mse'
    l1 = 'l1'
|
|
|
|
class GenerativeType(Enum):
    """
    How a sample is generated.
    """

    # DDPM-style sampling
    ddpm = 'ddpm'
    # DDIM-style sampling
    ddim = 'ddim'
|
|
|
|
class OptimizerType(Enum):
    """
    Which optimizer to train with.

    The mapping from these names to concrete optimizer classes is not
    defined in this enum.
    """

    adam = 'adam'
    adamw = 'adamw'
|
|
|
|
class Activation(Enum):
    """
    Activation functions selectable by name; ``get_act`` builds the module.
    """

    none = 'none'
    relu = 'relu'
    lrelu = 'lrelu'
    silu = 'silu'
    tanh = 'tanh'

    def get_act(self):
        """
        Construct and return a new ``torch.nn`` module for this activation.

        ``none`` yields ``nn.Identity``; ``lrelu`` uses a negative slope of 0.2.

        Raises:
            NotImplementedError: if ``self`` matches none of the known members.
        """
        # (member, zero-argument constructor) pairs, checked in order with ==;
        # each call constructs a fresh module instance.
        table = (
            (Activation.none, nn.Identity),
            (Activation.relu, nn.ReLU),
            (Activation.lrelu, lambda: nn.LeakyReLU(negative_slope=0.2)),
            (Activation.silu, nn.SiLU),
            (Activation.tanh, nn.Tanh),
        )
        for member, build in table:
            if self == member:
                return build()
        raise NotImplementedError()
|
|
|
|
class ManipulateLossType(Enum):
    """
    Loss used when training the manipulation classifier.

    ``bce`` presumably denotes binary cross-entropy and ``mse`` mean squared
    error (inferred from the names — confirm against the training code).
    """

    bce = 'bce'
    mse = 'mse'