# SupraMNST-IMG-200k / configuration.py
# Harley-ml's picture
# Update configuration.py
# 0b6a7dc verified
#!/usr/bin/env python3
# Configuration class for the SupraMNiST-IMG-200k model.
from __future__ import annotations
from typing import Iterable, Tuple
from transformers import PretrainedConfig
class DigitDiffusionConfig(PretrainedConfig):
    """Hyperparameter container for the ``digit_diffusion`` model type.

    Every integer argument is coerced with ``int()`` and must be strictly
    positive *after* coercion. ``block_out_channels`` is stored as a tuple
    of ints and serialized as a list by :meth:`to_dict`.

    Args:
        image_size: Positive integer image size (presumably the side length
            in pixels -- confirm against the model code).
        in_channels: Positive integer input channel count.
        out_channels: Positive integer output channel count.
        num_classes: Positive integer number of classes.
        block_out_channels: Non-empty iterable of positive integers; its
            length is exposed as :attr:`num_blocks`.
        layers_per_block: Positive integer; consumed by the model code,
            which is not visible from this module.
        norm_num_groups: Positive integer; consumed by the model code.
        cross_attention_dim: Positive integer; consumed by the model code.
        class_embed_type: Stored unchanged; may be ``None``.
        sample_size: Positive integer; falls back to ``image_size`` when
            ``None``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``. When absent,
            ``architectures`` defaults to ``["DigitDiffusionModel"]``.

    Raises:
        ValueError: If any numeric argument is not strictly positive after
            integer coercion, or if ``block_out_channels`` is empty.
            ``int()`` itself may also raise ``ValueError``/``TypeError`` for
            values that cannot be converted.
    """

    model_type = "digit_diffusion"

    def __init__(
        self,
        image_size: int = 32,
        in_channels: int = 1,
        out_channels: int = 1,
        num_classes: int = 10,
        block_out_channels: Iterable[int] = (12, 16),
        layers_per_block: int = 8,
        norm_num_groups: int = 4,
        cross_attention_dim: int = 8,
        class_embed_type: str | None = None,
        sample_size: int | None = None,
        **kwargs,
    ) -> None:
        # Coerce first, validate second. Validating the raw value would let
        # e.g. 0.5 pass a "> 0" check and then be truncated to 0 by int(),
        # and would reject numeric strings for some fields but not others.
        image_size = int(image_size)
        sample_size = image_size if sample_size is None else int(sample_size)
        block_out_channels = tuple(int(v) for v in block_out_channels)
        in_channels = int(in_channels)
        out_channels = int(out_channels)
        num_classes = int(num_classes)
        layers_per_block = int(layers_per_block)
        norm_num_groups = int(norm_num_groups)
        cross_attention_dim = int(cross_attention_dim)

        if not block_out_channels:
            raise ValueError("block_out_channels must contain at least one entry.")
        if any(v <= 0 for v in block_out_channels):
            raise ValueError("block_out_channels must contain only positive integers.")
        if image_size <= 0:
            raise ValueError("image_size must be a positive integer.")
        if sample_size <= 0:
            raise ValueError("sample_size must be a positive integer.")
        if in_channels <= 0 or out_channels <= 0:
            raise ValueError("in_channels and out_channels must be positive integers.")
        if num_classes <= 0:
            raise ValueError("num_classes must be a positive integer.")
        if layers_per_block <= 0:
            raise ValueError("layers_per_block must be a positive integer.")
        if norm_num_groups <= 0:
            raise ValueError("norm_num_groups must be a positive integer.")
        if cross_attention_dim <= 0:
            raise ValueError("cross_attention_dim must be a positive integer.")

        self.image_size = image_size
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_classes = num_classes
        self.block_out_channels: Tuple[int, ...] = block_out_channels
        self.layers_per_block = layers_per_block
        self.norm_num_groups = norm_num_groups
        self.cross_attention_dim = cross_attention_dim
        self.class_embed_type = class_embed_type

        # Only fill in the architecture name when the caller (or a loaded
        # config dict) has not supplied one.
        kwargs.setdefault("architectures", ["DigitDiffusionModel"])
        super().__init__(**kwargs)

    @property
    def num_blocks(self) -> int:
        """Return the number of entries in ``block_out_channels``."""
        return len(self.block_out_channels)

    def to_dict(self):
        """Return the parent's dict with ``block_out_channels`` as a list.

        The attribute is held as a tuple on the instance; the serialized
        form carries a plain list instead.
        """
        data = super().to_dict()
        data["block_out_channels"] = list(self.block_out_channels)
        return data
# Register the custom config class with the transformers auto-class machinery
# at import time. NOTE(review): called with no arguments, so the target is the
# library default (presumably AutoConfig) -- confirm against transformers docs.
DigitDiffusionConfig.register_for_auto_class()