| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
"""
Utility functions for creating schedules, samplers, and sampling timesteps from config.
"""
|
|
| import torch |
| from omegaconf import DictConfig |
|
|
| from .samplers.base import Sampler |
| from .samplers.euler import EulerSampler |
| from .schedules.base import Schedule |
| from .schedules.lerp import LinearInterpolationSchedule |
| from .timesteps.base import SamplingTimesteps |
| from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps |
|
|
|
|
def create_schedule_from_config(
    config: DictConfig,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> Schedule:
    """
    Create a schedule from configuration.

    Args:
        config: Schedule configuration. ``config.type`` selects the schedule;
            for ``"lerp"`` the optional ``T`` entry (default ``1.0``) is
            forwarded to ``LinearInterpolationSchedule``.
        device: Not used by the currently supported schedule types.
        dtype: Not used by the currently supported schedule types.

    Returns:
        The schedule instance selected by ``config.type``.

    Raises:
        NotImplementedError: If ``config.type`` is not a supported schedule type.
    """
    schedule_type = config.type
    if schedule_type == "lerp":
        return LinearInterpolationSchedule(T=config.get("T", 1.0))

    # Name the rejected value so a misconfigured run is easy to diagnose.
    raise NotImplementedError(
        f"Unsupported schedule type: {schedule_type!r} (supported: 'lerp')"
    )
|
|
|
|
def create_sampler_from_config(
    config: DictConfig,
    schedule: Schedule,
    timesteps: SamplingTimesteps,
) -> Sampler:
    """
    Create a sampler from configuration.

    Args:
        config: Sampler configuration. ``config.type`` selects the sampler;
            for ``"euler"`` the ``prediction_type`` entry is forwarded to
            ``EulerSampler``.
        schedule: Schedule passed through to the sampler.
        timesteps: Sampling timesteps passed through to the sampler.

    Returns:
        The sampler instance selected by ``config.type``.

    Raises:
        NotImplementedError: If ``config.type`` is not a supported sampler type.
    """
    sampler_type = config.type
    if sampler_type == "euler":
        return EulerSampler(
            schedule=schedule,
            timesteps=timesteps,
            prediction_type=config.prediction_type,
        )

    # Name the rejected value so a misconfigured run is easy to diagnose.
    raise NotImplementedError(
        f"Unsupported sampler type: {sampler_type!r} (supported: 'euler')"
    )
|
|
|
|
def create_sampling_timesteps_from_config(
    config: DictConfig,
    schedule: Schedule,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> SamplingTimesteps:
    """
    Create sampling timesteps from configuration.

    Args:
        config: Timesteps configuration. ``config.type`` selects the
            implementation; for ``"uniform_trailing"`` the ``steps`` entry and
            the optional ``shift`` entry (default ``1.0``) are forwarded to
            ``UniformTrailingSamplingTimesteps``.
        schedule: Schedule whose ``T`` attribute is forwarded to the timesteps.
        device: Device forwarded to the timesteps object.
        dtype: Accepted but not forwarded by the currently supported type.

    Returns:
        The sampling timesteps instance selected by ``config.type``.

    Raises:
        NotImplementedError: If ``config.type`` is not a supported timesteps type.
    """
    timesteps_type = config.type
    if timesteps_type == "uniform_trailing":
        return UniformTrailingSamplingTimesteps(
            T=schedule.T,
            steps=config.steps,
            shift=config.get("shift", 1.0),
            device=device,
        )

    # Name the rejected value so a misconfigured run is easy to diagnose.
    raise NotImplementedError(
        f"Unsupported sampling timesteps type: {timesteps_type!r} "
        "(supported: 'uniform_trailing')"
    )