File size: 1,785 Bytes
705a8fd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 | from . import gaussian_diffusion as gd_orig
from . import gaussian_diffusion_dual as gd_dual
# from .respace import SpacedDiffusion, space_timesteps
def create_diffusion(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000,
    dual=False
):
    """Build a ``SpacedDiffusion`` from either the single or the dual backend.

    Args:
        timestep_respacing: Respacing spec forwarded to ``space_timesteps``.
            ``None`` or ``""`` is replaced by ``[diffusion_steps]``.
        noise_schedule: Schedule name handed to ``get_named_beta_schedule``.
        use_kl: When true, the loss type is ``LossType.RESCALED_KL``.
        sigma_small: Only consulted when ``learn_sigma`` is false; picks
            ``ModelVarType.FIXED_SMALL`` instead of ``FIXED_LARGE``.
        predict_xstart: When true, the mean type is ``ModelMeanType.START_X``;
            otherwise ``ModelMeanType.EPSILON``.
        learn_sigma: When true, the variance type is
            ``ModelVarType.LEARNED_RANGE``.
        rescale_learned_sigmas: When true (and ``use_kl`` is false), the loss
            type is ``LossType.RESCALED_MSE``; otherwise ``LossType.MSE``.
        diffusion_steps: Step count used for the beta schedule and for
            timestep spacing.
        dual: When true, use ``gaussian_diffusion_dual`` / ``respace_dual``;
            otherwise use ``gaussian_diffusion`` / ``respace``.

    Returns:
        The ``SpacedDiffusion`` instance built from the selected backend.
    """
    # The respace module is imported lazily so only the chosen backend loads.
    if not dual:
        print("Using SINGLE diffusion")
        from .respace import SpacedDiffusion, space_timesteps
        backend = gd_orig
    else:
        print("Using DUAL diffusion")
        from .respace_dual import SpacedDiffusion, space_timesteps
        backend = gd_dual

    beta_schedule = backend.get_named_beta_schedule(noise_schedule, diffusion_steps)

    # KL takes precedence over the rescaled-MSE flag.
    if use_kl:
        chosen_loss = backend.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        chosen_loss = backend.LossType.RESCALED_MSE
    else:
        chosen_loss = backend.LossType.MSE

    # No respacing requested: fall back to a single section of all steps.
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]

    kept_timesteps = space_timesteps(diffusion_steps, timestep_respacing)

    if predict_xstart:
        mean_kind = backend.ModelMeanType.START_X
    else:
        mean_kind = backend.ModelMeanType.EPSILON

    if learn_sigma:
        var_kind = backend.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_kind = backend.ModelVarType.FIXED_SMALL
    else:
        var_kind = backend.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=kept_timesteps,
        betas=beta_schedule,
        model_mean_type=mean_kind,
        model_var_type=var_kind,
        loss_type=chosen_loss,
    )
|