| |
| |
|
|
| from __future__ import annotations |
|
|
| from typing import Iterable, Tuple |
|
|
| from transformers import PretrainedConfig |
|
|
|
|
class DigitDiffusionConfig(PretrainedConfig):
    """Configuration object for the ``digit_diffusion`` model type.

    Every numeric hyper-parameter is coerced with ``int()`` *before* it is
    validated. The value that passes validation is therefore exactly the
    value that is stored. For example, ``layers_per_block=0.5`` is rejected
    rather than silently truncated to ``0``, and numeric strings such as
    ``"3"`` are accepted.

    Args:
        image_size: Must be a positive integer.
        in_channels: Must be a positive integer.
        out_channels: Must be a positive integer.
        num_classes: Must be a positive integer.
        block_out_channels: Non-empty iterable of positive integers; stored
            as a tuple (see ``num_blocks``).
        layers_per_block: Must be a positive integer.
        norm_num_groups: Must be a positive integer.
        cross_attention_dim: Must be a positive integer.
        class_embed_type: Stored unchanged; may be ``None``.
        sample_size: Must be a positive integer; defaults to ``image_size``
            when ``None``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
            ``architectures`` defaults to ``["DigitDiffusionModel"]``.

    Raises:
        ValueError: If ``block_out_channels`` is empty or any size/count
            argument is not a positive integer, or if a value cannot be
            parsed by ``int()``.
        TypeError: If a value has a type ``int()`` does not accept
            (e.g. ``None`` for a required integer field).
    """

    model_type = "digit_diffusion"

    def __init__(
        self,
        image_size: int = 32,
        in_channels: int = 1,
        out_channels: int = 1,
        num_classes: int = 10,
        block_out_channels: Iterable[int] = (12, 16, 20),
        layers_per_block: int = 8,
        norm_num_groups: int = 4,
        cross_attention_dim: int = 8,
        class_embed_type: str | None = None,
        sample_size: int | None = None,
        **kwargs,
    ) -> None:
        # Coerce first, validate second. Checking the raw value and storing
        # int(value) would let e.g. 0.5 pass a ``<= 0`` test and then be
        # stored as 0.
        image_size = int(image_size)
        sample_size = int(sample_size) if sample_size is not None else image_size
        in_channels = int(in_channels)
        out_channels = int(out_channels)
        num_classes = int(num_classes)
        layers_per_block = int(layers_per_block)
        norm_num_groups = int(norm_num_groups)
        cross_attention_dim = int(cross_attention_dim)
        # Materialize once: the argument may be a one-shot iterable, and a
        # list loaded from JSON becomes a tuple here.
        block_out_channels = tuple(int(v) for v in block_out_channels)

        if not block_out_channels:
            raise ValueError("block_out_channels must contain at least one entry.")
        if any(v <= 0 for v in block_out_channels):
            raise ValueError("block_out_channels must contain only positive integers.")

        if image_size <= 0:
            raise ValueError("image_size must be a positive integer.")
        if sample_size <= 0:
            raise ValueError("sample_size must be a positive integer.")
        if in_channels <= 0 or out_channels <= 0:
            raise ValueError("in_channels and out_channels must be positive integers.")
        if num_classes <= 0:
            raise ValueError("num_classes must be a positive integer.")
        if layers_per_block <= 0:
            raise ValueError("layers_per_block must be a positive integer.")
        if norm_num_groups <= 0:
            raise ValueError("norm_num_groups must be a positive integer.")
        if cross_attention_dim <= 0:
            raise ValueError("cross_attention_dim must be a positive integer.")

        self.image_size = image_size
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_classes = num_classes
        self.block_out_channels: Tuple[int, ...] = block_out_channels
        self.layers_per_block = layers_per_block
        self.norm_num_groups = norm_num_groups
        self.cross_attention_dim = cross_attention_dim
        self.class_embed_type = class_embed_type

        # Only fill in the architecture name when the caller (or a loaded
        # config file) did not supply one.
        kwargs.setdefault("architectures", ["DigitDiffusionModel"])

        super().__init__(**kwargs)

    @property
    def num_blocks(self) -> int:
        """Return the number of entries in ``block_out_channels``."""
        return len(self.block_out_channels)

    def to_dict(self) -> dict:
        """Serialize the config, emitting ``block_out_channels`` as a list.

        The attribute is held as a tuple in memory; the serialized form
        uses a plain list instead.
        """
        data = super().to_dict()
        data["block_out_channels"] = list(self.block_out_channels)
        return data
|
|
|
|
# Register this custom configuration class with ``AutoConfig``.
# "AutoConfig" is also the method's default argument; it is spelled out here
# so the target auto class is visible at the call site.
DigitDiffusionConfig.register_for_auto_class("AutoConfig")
|
|