#!/usr/bin/env python3
# Model for MNiST-IMG-390k
"""``DigitDiffusionModel``: a transformers ``PreTrainedModel`` wrapping a diffusers ``UNet2DConditionModel``."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional

import torch
from diffusers import UNet2DConditionModel
from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from configuration import DigitDiffusionConfig


# Top-level key prefixes of a bare UNet state dict (no ``unet.`` wrapper
# prefix). Used by ``DigitDiffusionModel.load_state_dict`` to recognise such
# checkpoints and re-prefix them.
_PLAIN_UNET_KEY_PREFIXES = (
    "conv_in.",
    "conv_norm_out.",
    "conv_out.",
    "time_embedding.",
    "class_embedding.",
    "down_blocks.",
    "up_blocks.",
    "mid_block.",
)


def _to_batch_vector(
    value: Any,
    batch_size: int,
    device: torch.device,
    error_message: str,
) -> torch.Tensor:
    """Coerce a scalar, single-value, or per-sample value into a 1-D long tensor.

    Args:
        value: A Python number/sequence or a tensor of any shape.
        batch_size: Number of entries the result must have.
        device: Device of the returned tensor.
        error_message: Message of the ``ValueError`` raised on a size mismatch.

    Returns:
        A ``torch.long`` tensor of shape ``(batch_size,)`` on ``device``. A
        single element is broadcast to every sample. Float inputs are
        truncated by the cast to ``torch.long``.

    Raises:
        ValueError: If ``value`` holds neither 1 nor ``batch_size`` elements.
    """
    if not torch.is_tensor(value):
        value = torch.tensor(value, device=device, dtype=torch.long)
    # Flatten so that e.g. a (batch, 1) tensor becomes (batch,) instead of
    # reaching the UNet with an extra dimension.
    flat = value.reshape(-1)
    if flat.numel() == 1:
        flat = flat.expand(batch_size)
    elif flat.numel() != batch_size:
        raise ValueError(error_message)
    return flat.to(device=device, dtype=torch.long)


@dataclass
class DigitDiffusionOutput(ModelOutput):
    """Output container for :class:`DigitDiffusionModel`.

    Attributes:
        sample: The ``.sample`` field of the wrapped UNet's output.
    """

    sample: torch.FloatTensor | None = None


class DigitDiffusionModel(PreTrainedModel):
    """Transformers-compatible wrapper around a diffusers ``UNet2DConditionModel``.

    The UNet is built from ``DigitDiffusionConfig``. It uses plain
    ``DownBlock2D``/``UpBlock2D`` blocks (one per entry of
    ``config.block_out_channels``) and a ``UNetMidBlock2D``. It is stored under
    the ``unet`` attribute, which matches ``base_model_prefix``.
    """

    config_class = DigitDiffusionConfig
    base_model_prefix = "unet"
    main_input_name = "noisy_images"
    # Empty: this model declares no tied weights. NOTE(review): presumably set
    # explicitly for compatibility with the installed transformers version.
    all_tied_weights_keys = {}

    def __init__(self, config: DigitDiffusionConfig) -> None:
        super().__init__(config)
        block_count = len(config.block_out_channels)
        self.unet = UNet2DConditionModel(
            sample_size=config.sample_size,
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            layers_per_block=config.layers_per_block,
            block_out_channels=tuple(config.block_out_channels),
            down_block_types=("DownBlock2D",) * block_count,
            up_block_types=("UpBlock2D",) * block_count,
            mid_block_type="UNetMidBlock2D",
            norm_num_groups=config.norm_num_groups,
            num_class_embeds=config.num_classes,
            cross_attention_dim=config.cross_attention_dim,
            class_embed_type=config.class_embed_type,
        )
        self.post_init()

    def _init_weights(self, module):
        # Diffusers initializes the UNet internally, so there is nothing extra
        # to initialize here.
        return

    def _make_dummy_context(
        self,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Return an all-zeros ``(batch_size, 1, cross_attention_dim)`` context.

        ``forward`` uses this as ``encoder_hidden_states`` when the caller
        passes none. NOTE(review): the configured blocks are not
        cross-attention blocks, so this presumably only satisfies the UNet's
        call signature. Confirm against the diffusers version in use.
        """
        return torch.zeros(
            batch_size,
            1,
            self.config.cross_attention_dim,
            device=device,
            dtype=dtype,
        )

    def _normalize_inputs(
        self,
        noisy_images: Optional[torch.Tensor] = None,
        timesteps: Optional[torch.Tensor | int] = None,
        sample: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor | int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Resolve input aliases and shape the timesteps to one per sample.

        ``noisy_images``/``timesteps`` take precedence over the aliases
        ``sample``/``timestep``.

        Returns:
            ``(noisy_images, timesteps)``. The timesteps form a ``torch.long``
            tensor of shape ``(batch,)`` on the images' device, where
            ``batch = noisy_images.shape[0]``.

        Raises:
            ValueError: If no image or timestep input is given, or if the
                timesteps hold neither 1 nor ``batch`` elements.
        """
        if noisy_images is None:
            noisy_images = sample
        if timesteps is None:
            timesteps = timestep

        if noisy_images is None:
            raise ValueError("Either `noisy_images` or `sample` must be provided.")
        if timesteps is None:
            raise ValueError("Either `timesteps` or `timestep` must be provided.")

        timesteps = _to_batch_vector(
            timesteps,
            batch_size=noisy_images.shape[0],
            device=noisy_images.device,
            error_message=(
                "Timesteps must be a scalar, a batch-sized tensor, or a single-value tensor."
            ),
        )
        return noisy_images, timesteps

    def forward(
        self,
        noisy_images: Optional[torch.Tensor] = None,
        timesteps: Optional[torch.Tensor | int] = None,
        class_labels: Optional[torch.Tensor] = None,
        sample: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor | int] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs: Any,
    ):
        """Run the wrapped UNet on a batch of images.

        Args:
            noisy_images: Input batch. The batch size is read from dim 0.
            timesteps: A scalar, a single value, or one value per sample.
                Values are cast to ``torch.long``.
            class_labels: A scalar, a single value, or one label per sample.
                Defaults to all zeros. Values are cast to ``torch.long``.
            sample: Alias for ``noisy_images``, matching the UNet's keyword name.
            timestep: Alias for ``timesteps``, matching the UNet's keyword name.
            encoder_hidden_states: Defaults to an all-zeros placeholder from
                ``_make_dummy_context``.
            return_dict: If true, return a ``DigitDiffusionOutput``. Otherwise
                return a 1-tuple.
            **kwargs: Forwarded unchanged to ``self.unet``.

        Returns:
            ``DigitDiffusionOutput(sample=...)`` or ``(sample,)``, where
            ``sample`` is the UNet's ``.sample`` output.

        Raises:
            ValueError: On missing inputs, or when timesteps or class labels
                hold neither 1 nor ``batch`` elements.
        """
        noisy_images, timesteps = self._normalize_inputs(
            noisy_images=noisy_images,
            timesteps=timesteps,
            sample=sample,
            timestep=timestep,
        )
        batch_size = noisy_images.shape[0]

        if class_labels is None:
            class_labels = torch.zeros(
                batch_size,
                device=noisy_images.device,
                dtype=torch.long,
            )
        else:
            # Accept ints/lists too, and keep labels 1-D: a (batch, 1) tensor
            # would otherwise yield a rank-3 class embedding.
            class_labels = _to_batch_vector(
                class_labels,
                batch_size=batch_size,
                device=noisy_images.device,
                error_message=(
                    "Class labels must be a scalar, a batch-sized tensor, or a single-value tensor."
                ),
            )

        if encoder_hidden_states is None:
            encoder_hidden_states = self._make_dummy_context(
                batch_size=batch_size,
                device=noisy_images.device,
                dtype=noisy_images.dtype,
            )

        noise_pred = self.unet(
            sample=noisy_images,
            timestep=timesteps,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            return_dict=True,
            **kwargs,
        ).sample

        if return_dict:
            return DigitDiffusionOutput(sample=noise_pred)
        return (noise_pred,)

    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
        """Load weights, accepting bare UNet state dicts as well as wrapped ones.

        A state dict with no ``unet.``-prefixed key but at least one key from
        ``_PLAIN_UNET_KEY_PREFIXES`` is treated as a bare UNet checkpoint. All
        of its keys are prefixed with ``unet.``. Any other state dict
        (including an empty one) is passed through unchanged to
        ``nn.Module.load_state_dict``.
        """
        if state_dict:
            keys = list(state_dict.keys())
            has_prefixed = any(k.startswith("unet.") for k in keys)
            has_plain_unet = any(k.startswith(_PLAIN_UNET_KEY_PREFIXES) for k in keys)
            if has_plain_unet and not has_prefixed:
                state_dict = {f"unet.{k}": v for k, v in state_dict.items()}
        return super().load_state_dict(state_dict, strict=strict, assign=assign)


DigitDiffusionModel.register_for_auto_class("AutoModel")