#!/usr/bin/env python3
# Configuration for the MNiST-IMG-390k

from __future__ import annotations

from typing import Iterable, Tuple

from transformers import PretrainedConfig


class DigitDiffusionConfig(PretrainedConfig):
    """Hyperparameter container for the ``digit_diffusion`` model type.

    Every integer argument is coerced with ``int()`` *before* it is checked
    for positivity, so the value that is validated is exactly the value that
    is stored. (Fractional inputs are truncated by ``int()``; e.g. ``0.5``
    becomes ``0`` and is rejected.)

    Args:
        image_size: Positive integer image size.
        in_channels: Positive integer number of input channels.
        out_channels: Positive integer number of output channels.
        num_classes: Positive integer number of classes.
        block_out_channels: Non-empty iterable of positive integers; stored
            as a tuple and serialized as a list by :meth:`to_dict`.
        layers_per_block: Positive integer.
        norm_num_groups: Positive integer.
        cross_attention_dim: Positive integer.
        class_embed_type: Stored unchanged; not validated here.
        sample_size: Positive integer; falls back to ``image_size`` when
            ``None``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``. The
            ``architectures`` key defaults to ``["DigitDiffusionModel"]``.

    Raises:
        ValueError: If any numeric argument is not strictly positive after
            coercion, or if ``block_out_channels`` is empty.
        TypeError: If a numeric argument cannot be converted by ``int()``
            (for example ``None`` for ``image_size``).
    """

    model_type = "digit_diffusion"

    def __init__(
        self,
        image_size: int = 32,
        in_channels: int = 1,
        out_channels: int = 1,
        num_classes: int = 10,
        block_out_channels: Iterable[int] = (12, 16, 20),
        layers_per_block: int = 8,
        norm_num_groups: int = 4,
        cross_attention_dim: int = 8,
        class_embed_type: str | None = None,
        sample_size: int | None = None,
        **kwargs,
    ) -> None:
        image_size = int(image_size)
        sample_size = int(sample_size) if sample_size is not None else image_size
        channels: Tuple[int, ...] = tuple(int(v) for v in block_out_channels)

        if not channels:
            raise ValueError("block_out_channels must contain at least one entry.")
        if any(v <= 0 for v in channels):
            raise ValueError("block_out_channels must contain only positive integers.")
        if image_size <= 0:
            raise ValueError("image_size must be a positive integer.")
        if sample_size <= 0:
            raise ValueError("sample_size must be a positive integer.")

        # Coerce first, then validate: checking the raw value would let
        # e.g. 0.5 pass ``> 0`` and then be stored as int(0.5) == 0.
        in_channels = int(in_channels)
        out_channels = int(out_channels)
        if in_channels <= 0 or out_channels <= 0:
            raise ValueError("in_channels and out_channels must be positive integers.")

        num_classes = int(num_classes)
        if num_classes <= 0:
            raise ValueError("num_classes must be a positive integer.")

        layers_per_block = int(layers_per_block)
        if layers_per_block <= 0:
            raise ValueError("layers_per_block must be a positive integer.")

        norm_num_groups = int(norm_num_groups)
        if norm_num_groups <= 0:
            raise ValueError("norm_num_groups must be a positive integer.")

        cross_attention_dim = int(cross_attention_dim)
        if cross_attention_dim <= 0:
            raise ValueError("cross_attention_dim must be a positive integer.")

        self.image_size = image_size
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_classes = num_classes
        self.block_out_channels = channels
        self.layers_per_block = layers_per_block
        self.norm_num_groups = norm_num_groups
        self.cross_attention_dim = cross_attention_dim
        self.class_embed_type = class_embed_type

        # Handy for HF model pages and AutoClass loading.
        kwargs.setdefault("architectures", ["DigitDiffusionModel"])
        super().__init__(**kwargs)

    @property
    def num_blocks(self) -> int:
        """Return the number of entries in ``block_out_channels``."""
        return len(self.block_out_channels)

    def to_dict(self):
        """Return the parent's dict with ``block_out_channels`` as a list."""
        data = super().to_dict()
        # Keep the serialized values compact and JSON-friendly.
        data["block_out_channels"] = list(self.block_out_channels)
        return data


# Module-level side effect kept from the original file: registers this
# config class via the auto-class hook inherited from PretrainedConfig.
DigitDiffusionConfig.register_for_auto_class()