SupraMNST-IMG-200k / modeling.py
Harley-ml's picture
Update modeling.py
880c288 verified
#!/usr/bin/env python3
# Model for SupraMNiST-IMG-200k
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Optional
import torch
from diffusers import UNet2DConditionModel
from transformers import PreTrainedModel
from transformers.utils import ModelOutput
from .configuration import DigitDiffusionConfig
@dataclass
class DigitDiffusionOutput(ModelOutput):
    """Output container returned by ``DigitDiffusionModel.forward`` when ``return_dict=True``.

    Attributes:
        sample: Tensor produced by the wrapped UNet for the given inputs, or
            ``None`` when the field has not been populated.
    """

    # UNet prediction for the batch (left as None until filled in).
    sample: Optional[torch.FloatTensor] = None
class DigitDiffusionModel(PreTrainedModel):
    """Transformers ``PreTrainedModel`` wrapper around a diffusers ``UNet2DConditionModel``.

    The UNet is built from :class:`DigitDiffusionConfig` with:

    * one ``DownBlock2D`` and one ``UpBlock2D`` per entry of ``config.block_out_channels``;
    * a ``UNetMidBlock2D`` middle block;
    * a class-label embedding sized by ``config.num_classes``.

    ``forward`` accepts either this wrapper's argument names
    (``noisy_images`` / ``timesteps``) or the diffusers-style names
    (``sample`` / ``timestep``).
    """

    config_class = DigitDiffusionConfig
    base_model_prefix = "unet"
    main_input_name = "noisy_images"
    # No tied-weight keys are declared for this model.
    all_tied_weights_keys = {}

    # Top-level parameter prefixes of a bare ``UNet2DConditionModel`` state dict.
    # ``load_state_dict`` uses them to recognise checkpoints that were saved
    # without the ``"unet."`` wrapper prefix.
    _PLAIN_UNET_KEY_PREFIXES = (
        "conv_in.",
        "conv_norm_out.",
        "conv_out.",
        "time_embedding.",
        "class_embedding.",
        "down_blocks.",
        "up_blocks.",
        "mid_block.",
    )

    def __init__(self, config: DigitDiffusionConfig) -> None:
        """Build the wrapped UNet from ``config`` and run transformers' post-init hooks."""
        super().__init__(config)
        # One down block and one up block per channel-width entry.
        block_count = len(config.block_out_channels)
        self.unet = UNet2DConditionModel(
            sample_size=config.sample_size,
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            layers_per_block=config.layers_per_block,
            block_out_channels=tuple(config.block_out_channels),
            down_block_types=("DownBlock2D",) * block_count,
            up_block_types=("UpBlock2D",) * block_count,
            mid_block_type="UNetMidBlock2D",
            norm_num_groups=config.norm_num_groups,
            num_class_embeds=config.num_classes,
            cross_attention_dim=config.cross_attention_dim,
            class_embed_type=config.class_embed_type,
        )
        self.post_init()

    def _init_weights(self, module):
        # Diffusers initializes the UNet internally, so there is nothing extra
        # to initialize here.
        return

    def _make_dummy_context(
        self,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Return a zero context tensor of shape ``(batch_size, 1, config.cross_attention_dim)``.

        Used as ``encoder_hidden_states`` when the caller supplies none, because
        the UNet's forward signature requires one.
        """
        return torch.zeros(
            batch_size,
            1,
            self.config.cross_attention_dim,
            device=device,
            dtype=dtype,
        )

    def _normalize_inputs(
        self,
        noisy_images: Optional[torch.Tensor] = None,
        timesteps: Optional[torch.Tensor | int] = None,
        sample: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor | int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Resolve argument aliases and coerce timesteps to a 1-D batch-sized ``long`` tensor.

        ``noisy_images`` / ``timesteps`` take precedence over their aliases
        ``sample`` / ``timestep`` when both are given.

        Returns:
            ``(noisy_images, timesteps)``, where ``timesteps`` has shape
            ``(noisy_images.shape[0],)`` and lives on ``noisy_images.device``.

        Raises:
            ValueError: If the images or timesteps are missing, or the
                timesteps cannot be matched to the batch size.
        """
        if noisy_images is None:
            noisy_images = sample
        if timesteps is None:
            timesteps = timestep
        if noisy_images is None:
            raise ValueError("Either `noisy_images` or `sample` must be provided.")
        if timesteps is None:
            raise ValueError("Either `timesteps` or `timestep` must be provided.")

        batch_size = noisy_images.shape[0]
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor(
                timesteps,
                device=noisy_images.device,
                dtype=torch.long,
            )
        # Always flatten: scalars become (1,), and e.g. a (batch, 1) tensor becomes
        # (batch,) instead of reaching the UNet with an extra dimension.
        timesteps = timesteps.reshape(-1)
        if timesteps.numel() == 1:
            timesteps = timesteps.expand(batch_size)
        elif timesteps.shape[0] != batch_size:
            raise ValueError(
                "Timesteps must be a scalar, a batch-sized tensor, or a single-value tensor."
            )
        # NOTE(review): casting to long truncates fractional timesteps (some
        # schedulers emit float timesteps); kept as-is to match existing behavior.
        return noisy_images, timesteps.to(device=noisy_images.device, dtype=torch.long)

    def _normalize_class_labels(
        self,
        class_labels: torch.Tensor | int,
        batch_size: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Coerce class labels to a 1-D ``long`` tensor of length ``batch_size`` on ``device``.

        Accepts tensors, Python ints, or sequences. A single value is broadcast
        to every sample in the batch.

        Raises:
            ValueError: If more than one label is given and the count does not
                equal ``batch_size``.
        """
        labels = torch.as_tensor(class_labels, device=device, dtype=torch.long).reshape(-1)
        if labels.numel() == 1:
            return labels.expand(batch_size)
        if labels.shape[0] != batch_size:
            raise ValueError(
                "Class labels must be a scalar, a batch-sized tensor, or a single-value tensor."
            )
        return labels

    def forward(
        self,
        noisy_images: Optional[torch.Tensor] = None,
        timesteps: Optional[torch.Tensor | int] = None,
        class_labels: Optional[torch.Tensor | int] = None,
        sample: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor | int] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs: Any,
    ):
        """Run the wrapped UNet on a batch of noisy images.

        Args:
            noisy_images: Input batch (alias: ``sample``). Its first dimension is
                taken as the batch size.
            timesteps: Scalar, single-value, or batch-sized timesteps
                (alias: ``timestep``). They are cast to ``torch.long``.
            class_labels: Optional class indices (tensor, int, or sequence). A
                single value is broadcast to the batch; otherwise exactly one
                label per sample is required. When omitted, every sample uses
                class index 0.
            sample: Diffusers-style alias for ``noisy_images``.
            timestep: Diffusers-style alias for ``timesteps``.
            encoder_hidden_states: Optional context for the UNet. When omitted, a
                zero tensor of shape ``(batch, 1, config.cross_attention_dim)``
                is used.
            return_dict: When True, return a :class:`DigitDiffusionOutput`;
                otherwise return a 1-tuple.
            **kwargs: Forwarded unchanged to the wrapped UNet.

        Returns:
            ``DigitDiffusionOutput(sample=noise_pred)`` or ``(noise_pred,)``, where
            ``noise_pred`` is the UNet's ``.sample`` output.

        Raises:
            ValueError: If inputs or timesteps are missing, or timesteps or class
                labels cannot be matched to the batch size.
        """
        noisy_images, timesteps = self._normalize_inputs(
            noisy_images=noisy_images,
            timesteps=timesteps,
            sample=sample,
            timestep=timestep,
        )
        batch_size = noisy_images.shape[0]
        device = noisy_images.device

        if class_labels is None:
            # NOTE(review): index 0 is used as the default label. Whether it is a
            # real digit class or a reserved "unconditional" label is not visible
            # here; confirm against the training code.
            class_labels = torch.zeros(batch_size, device=device, dtype=torch.long)
        else:
            class_labels = self._normalize_class_labels(class_labels, batch_size, device)

        if encoder_hidden_states is None:
            encoder_hidden_states = self._make_dummy_context(
                batch_size=batch_size,
                device=device,
                dtype=noisy_images.dtype,
            )

        noise_pred = self.unet(
            sample=noisy_images,
            timestep=timesteps,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            return_dict=True,
            **kwargs,
        ).sample

        if return_dict:
            return DigitDiffusionOutput(sample=noise_pred)
        return (noise_pred,)

    @staticmethod
    def _prefix_unet_keys(state_dict):
        """Return a copy of ``state_dict`` with every key prefixed by ``"unet."``.

        ``torch.nn.Module.state_dict()`` may attach a ``_metadata`` mapping keyed
        by module path (``""`` for the root). It is re-keyed the same way, so
        per-module metadata still lines up after the rename instead of being
        silently dropped.
        """
        from collections import OrderedDict

        remapped = OrderedDict((f"unet.{key}", value) for key, value in state_dict.items())
        metadata = getattr(state_dict, "_metadata", None)
        if metadata is not None:
            remapped._metadata = OrderedDict(
                ("unet" if path == "" else f"unet.{path}", info)
                for path, info in metadata.items()
            )
        return remapped

    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
        """Load weights, accepting both wrapper checkpoints and bare-UNet checkpoints.

        If any key starts with one of ``_PLAIN_UNET_KEY_PREFIXES`` and no key
        already starts with ``"unet."``, every key is prefixed with ``"unet."``
        first. The result is then passed to ``torch.nn.Module.load_state_dict``
        with the same ``strict`` / ``assign`` arguments.
        """
        if state_dict:
            keys = list(state_dict.keys())
            has_prefixed = any(k.startswith("unet.") for k in keys)
            has_plain_unet = any(k.startswith(self._PLAIN_UNET_KEY_PREFIXES) for k in keys)
            if has_plain_unet and not has_prefixed:
                state_dict = self._prefix_unet_keys(state_dict)
        return super().load_state_dict(state_dict, strict=strict, assign=assign)
# Register this custom model class with transformers' AutoModel auto class.
DigitDiffusionModel.register_for_auto_class(auto_class="AutoModel")