| |
| |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Any, Optional |
|
|
| import torch |
| from diffusers import UNet2DConditionModel |
| from transformers import PreTrainedModel |
| from transformers.utils import ModelOutput |
|
|
| from configuration import DigitDiffusionConfig |
|
|
|
|
@dataclass
class DigitDiffusionOutput(ModelOutput):
    """Output container returned by ``DigitDiffusionModel.forward``.

    Attributes:
        sample: The tensor produced by the wrapped UNet (its ``.sample``
            output), or ``None`` when the field has not been populated.
    """

    sample: Optional[torch.FloatTensor] = None
|
|
|
|
class DigitDiffusionModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a diffusers ``UNet2DConditionModel``.

    The wrapped UNet is built from ``DigitDiffusionConfig`` using plain
    ``DownBlock2D`` / ``UpBlock2D`` blocks (one per entry of
    ``config.block_out_channels``) and a ``UNetMidBlock2D`` middle block.
    ``forward`` accepts both this model's argument names
    (``noisy_images`` / ``timesteps``) and the UNet-style aliases
    (``sample`` / ``timestep``).
    """

    config_class = DigitDiffusionConfig
    base_model_prefix = "unet"
    main_input_name = "noisy_images"
    # Empty mapping: this model declares no tied weights.
    # NOTE(review): presumably consulted by the transformers loading code;
    # confirm against the installed transformers version.
    all_tied_weights_keys = {}

    def __init__(self, config: DigitDiffusionConfig) -> None:
        """Build the wrapped UNet from ``config`` and run ``post_init``."""
        super().__init__(config)

        # One down block and one up block per configured channel width.
        block_count = len(config.block_out_channels)

        self.unet = UNet2DConditionModel(
            sample_size=config.sample_size,
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            layers_per_block=config.layers_per_block,
            block_out_channels=tuple(config.block_out_channels),
            down_block_types=("DownBlock2D",) * block_count,
            up_block_types=("UpBlock2D",) * block_count,
            mid_block_type="UNetMidBlock2D",
            norm_num_groups=config.norm_num_groups,
            num_class_embeds=config.num_classes,
            cross_attention_dim=config.cross_attention_dim,
            class_embed_type=config.class_embed_type,
        )

        self.post_init()

    def _init_weights(self, module):
        """Deliberately do nothing for ``module``.

        NOTE(review): presumably this keeps whatever initialization the
        UNet constructor already applied -- confirm.
        """
        return

    def _make_dummy_context(
        self,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Return an all-zero ``(batch_size, 1, cross_attention_dim)`` tensor.

        Used by ``forward`` as ``encoder_hidden_states`` when the caller
        supplies none.
        """
        return torch.zeros(
            batch_size,
            1,
            self.config.cross_attention_dim,
            device=device,
            dtype=dtype,
        )

    def _normalize_inputs(
        self,
        noisy_images: Optional[torch.Tensor] = None,
        timesteps: Optional[torch.Tensor | int] = None,
        sample: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor | int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Resolve argument aliases and return 1-D, batch-sized timesteps.

        Args:
            noisy_images: Input images; takes precedence over ``sample``.
            timesteps: Timesteps; takes precedence over ``timestep``.
            sample: Alias used when ``noisy_images`` is ``None``.
            timestep: Alias used when ``timesteps`` is ``None``.

        Returns:
            ``(noisy_images, timesteps)`` where ``timesteps`` is a 1-D
            ``torch.long`` tensor on the images' device whose length equals
            ``noisy_images.shape[0]``.

        Raises:
            ValueError: If no images or no timesteps were provided, or if
                the timesteps hold neither one value nor exactly one value
                per batch element.
        """
        if noisy_images is None:
            noisy_images = sample
        if timesteps is None:
            timesteps = timestep

        if noisy_images is None:
            raise ValueError("Either `noisy_images` or `sample` must be provided.")
        if timesteps is None:
            raise ValueError("Either `timesteps` or `timestep` must be provided.")

        batch_size = noisy_images.shape[0]

        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor(
                timesteps,
                device=noisy_images.device,
                dtype=torch.long,
            )

        # Always flatten before validating: 0-dim, (1,), (B,) and (B, 1)
        # inputs all become 1-D, so a multi-dimensional tensor whose leading
        # dimension happens to equal the batch size can no longer slip
        # through un-flattened.
        timesteps = timesteps.reshape(-1)
        if timesteps.numel() == 1:
            # Broadcast a single value to every batch element.
            timesteps = timesteps.expand(batch_size)
        elif timesteps.shape[0] != batch_size:
            raise ValueError(
                "Timesteps must be a scalar, a batch-sized tensor, or a single-value tensor."
            )

        return noisy_images, timesteps.to(device=noisy_images.device, dtype=torch.long)

    def forward(
        self,
        noisy_images: Optional[torch.Tensor] = None,
        timesteps: Optional[torch.Tensor | int] = None,
        class_labels: Optional[torch.Tensor] = None,
        sample: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor | int] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs: Any,
    ):
        """Run the wrapped UNet on the given images and timesteps.

        Args:
            noisy_images: Input images (alias: ``sample``).
            timesteps: Scalar, single-value or per-sample timesteps
                (alias: ``timestep``).
            class_labels: Class labels; a tensor, or any value accepted by
                ``torch.tensor``. Defaults to all zeros (one per batch
                element) when omitted.
            sample: Alias for ``noisy_images``.
            timestep: Alias for ``timesteps``.
            encoder_hidden_states: Conditioning tensor passed to the UNet;
                an all-zero placeholder is created when omitted.
            return_dict: When truthy, return a ``DigitDiffusionOutput``;
                otherwise return a 1-tuple.
            **kwargs: Forwarded unchanged to the wrapped UNet call.

        Returns:
            ``DigitDiffusionOutput(sample=...)`` or ``(sample,)`` holding
            the UNet's ``.sample`` output.

        Raises:
            ValueError: Propagated from ``_normalize_inputs`` for missing
                or mis-shaped inputs.
        """
        noisy_images, timesteps = self._normalize_inputs(
            noisy_images=noisy_images,
            timesteps=timesteps,
            sample=sample,
            timestep=timestep,
        )

        batch_size = noisy_images.shape[0]
        device = noisy_images.device

        if class_labels is None:
            class_labels = torch.zeros(batch_size, device=device, dtype=torch.long)
        elif torch.is_tensor(class_labels):
            class_labels = class_labels.to(device=device, dtype=torch.long)
        else:
            # Mirror the timestep handling: accept plain ints / sequences
            # instead of failing on the missing ``.to`` attribute.
            class_labels = torch.tensor(class_labels, device=device, dtype=torch.long)

        if encoder_hidden_states is None:
            encoder_hidden_states = self._make_dummy_context(
                batch_size=batch_size,
                device=device,
                dtype=noisy_images.dtype,
            )

        # Always request the dict-style output from the UNet so ``.sample``
        # is available regardless of this method's own ``return_dict``.
        noise_pred = self.unet(
            sample=noisy_images,
            timestep=timesteps,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            return_dict=True,
            **kwargs,
        ).sample

        if return_dict:
            return DigitDiffusionOutput(sample=noise_pred)
        return (noise_pred,)

    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
        """Load weights, also accepting keys that lack the ``unet.`` prefix.

        When no key starts with ``"unet."`` but at least one key starts
        with a recognized UNet submodule name, every key is prefixed with
        ``"unet."`` before delegating to the parent implementation.

        Args:
            state_dict: Mapping from parameter names to tensors.
            strict: Forwarded to the parent ``load_state_dict``.
            assign: Forwarded to the parent ``load_state_dict``.

        Returns:
            Whatever the parent ``load_state_dict`` returns.
        """
        if state_dict:
            plain_unet_prefixes = (
                "conv_in.",
                "conv_norm_out.",
                "conv_out.",
                "time_embedding.",
                "class_embedding.",
                "down_blocks.",
                "up_blocks.",
                "mid_block.",
            )
            has_prefixed = any(key.startswith("unet.") for key in state_dict)
            has_plain_unet = any(
                key.startswith(plain_unet_prefixes) for key in state_dict
            )

            if has_plain_unet and not has_prefixed:
                state_dict = {f"unet.{key}": value for key, value in state_dict.items()}

        return super().load_state_dict(state_dict, strict=strict, assign=assign)
|
|
|
|
# Register the class under ``AutoModel`` at import time.
DigitDiffusionModel.register_for_auto_class(auto_class="AutoModel")
|
|